diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/client/ClientCookieMiddleware.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/client/ClientCookieMiddleware.java new file mode 100644 index 00000000000..56f24e1019c --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/client/ClientCookieMiddleware.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.client; + +import java.net.HttpCookie; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.stream.Collectors; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClientMiddleware; +import org.apache.arrow.util.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A client middleware for receiving and sending cookie information. + * Note that this class will not persist permanent cookies beyond the lifetime + * of this session. + * + * This middleware will automatically remove cookies that have expired. + * Note: Negative max-age values currently do not get marked as expired due to + * a JDK issue. Use max-age=0 to explicitly remove an existing cookie. + */ +public class ClientCookieMiddleware implements FlightClientMiddleware { + private static final Logger LOGGER = LoggerFactory.getLogger(ClientCookieMiddleware.class); + + private static final String SET_COOKIE_HEADER = "Set-Cookie"; + private static final String COOKIE_HEADER = "Cookie"; + + private final Factory factory; + + @VisibleForTesting + ClientCookieMiddleware(Factory factory) { + this.factory = factory; + } + + /** + * Factory used within FlightClient. + */ + public static class Factory implements FlightClientMiddleware.Factory { + // Use a map to track the most recent version of a cookie from the server. + // Note that cookie names are case-sensitive (but header names aren't). + private ConcurrentMap cookies = new ConcurrentHashMap<>(); + + @Override + public ClientCookieMiddleware onCallStarted(CallInfo info) { + return new ClientCookieMiddleware(this); + } + + private void updateCookies(Iterable newCookieHeaderValues) { + // Note: Intentionally overwrite existing cookie values. + // A cookie defined once will continue to be used in all subsequent + // requests on the client instance. The server can send the same cookie again + // with a different value and the client will use the new value in future requests. + // The server can also update a cookie to have an Expiry in the past or negative age + // to signal that the client should stop using the cookie immediately. + newCookieHeaderValues.forEach(headerValue -> { + try { + final List parsedCookies = HttpCookie.parse(headerValue); + parsedCookies.forEach(parsedCookie -> { + final String cookieNameLc = parsedCookie.getName().toLowerCase(Locale.ENGLISH); + if (parsedCookie.hasExpired()) { + cookies.remove(cookieNameLc); + } else { + cookies.put(parsedCookie.getName().toLowerCase(Locale.ENGLISH), parsedCookie); + } + }); + } catch (IllegalArgumentException ex) { + LOGGER.warn("Skipping incorrectly formatted Set-Cookie header with value '{}'.", headerValue); + } + }); + } + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + final String cookieValue = getValidCookiesAsString(); + if (!cookieValue.isEmpty()) { + outgoingHeaders.insert(COOKIE_HEADER, cookieValue); + } + } + + @Override + public void onHeadersReceived(CallHeaders incomingHeaders) { + final Iterable setCookieHeaders = incomingHeaders.getAll(SET_COOKIE_HEADER); + if (setCookieHeaders != null) { + factory.updateCookies(setCookieHeaders); + } + } + + @Override + public void onCallCompleted(CallStatus status) { + + } + + /** + * Discards expired cookies and returns the valid cookies as a String delimited by ';'. + */ + @VisibleForTesting + String getValidCookiesAsString() { + // Discard expired cookies. + factory.cookies.entrySet().removeIf(cookieEntry -> cookieEntry.getValue().hasExpired()); + + // Cookie header value format: + // [=; = cookie.getValue().toString()) + .collect(Collectors.joining("; ")); + } +} diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java new file mode 100644 index 00000000000..f205f9a3b63 --- /dev/null +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.client; + +import java.io.IOException; + +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.ErrorFlightMetadata; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightMethod; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.FlightTestUtil; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.RequestContext; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +/** + * Tests for correct handling of cookies from the FlightClient using {@link ClientCookieMiddleware}. + */ +public class TestCookieHandling { + private static final String SET_COOKIE_HEADER = "Set-Cookie"; + private static final String COOKIE_HEADER = "Cookie"; + private BufferAllocator allocator; + private FlightServer server; + private FlightClient client; + + private ClientCookieMiddlewareTestFactory testFactory = new ClientCookieMiddlewareTestFactory(); + private ClientCookieMiddleware cookieMiddleware = new ClientCookieMiddleware(testFactory); + + @Before + public void setup() throws Exception { + allocator = new RootAllocator(Long.MAX_VALUE); + startServerAndClient(); + } + + @After + public void cleanup() throws Exception { + testFactory = new ClientCookieMiddlewareTestFactory(); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + AutoCloseables.close(client, server, allocator); + client = null; + server = null; + allocator = null; + } + + @Test + public void basicCookie() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + } + + @Test + public void cookieStaysAfterMultipleRequests() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + + headersToSend = new ErrorFlightMetadata(); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + + headersToSend = new ErrorFlightMetadata(); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + } + + @Ignore + @Test + public void cookieAutoExpires() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + // Note: using max-age changes cookie version from 0->1, which quotes values. + Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + + headersToSend = new ErrorFlightMetadata(); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + + try { + Thread.sleep(5000); + } catch (InterruptedException ignored) { + } + + // Verify that the k cookie was discarded because it expired. + Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); + } + + @Test + public void cookieExplicitlyExpires() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + // Note: using max-age changes cookie version from 0->1, which quotes values. + Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + + // Note: The JDK treats Max-Age < 0 as not expired and treats 0 as expired. + // This violates the RFC, which states that less than zero and zero should both be expired. + headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=0"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + + // Verify that the k cookie was discarded because the server told the client it is expired. + Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); + } + + @Ignore + @Test + public void cookieExplicitlyExpiresWithMaxAgeMinusOne() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + // Note: using max-age changes cookie version from 0->1, which quotes values. + Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + + headersToSend = new ErrorFlightMetadata(); + + // The Java HttpCookie class has a bug where it uses a -1 maxAge to indicate + // a persistent cookie, when the RFC spec says this should mean the cookie expires immediately. + headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=-1"); + cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); + cookieMiddleware.onHeadersReceived(headersToSend); + + // Verify that the k cookie was discarded because the server told the client it is expired. + Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); + } + + @Test + public void changeCookieValue() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v"); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + + headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "k=v2"); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("k=v2", cookieMiddleware.getValidCookiesAsString()); + } + + @Test + public void multipleCookiesWithSetCookie() { + CallHeaders headersToSend = new ErrorFlightMetadata(); + headersToSend.insert(SET_COOKIE_HEADER, "firstKey=firstVal"); + headersToSend.insert(SET_COOKIE_HEADER, "secondKey=secondVal"); + cookieMiddleware.onHeadersReceived(headersToSend); + Assert.assertEquals("firstKey=firstVal; secondKey=secondVal", cookieMiddleware.getValidCookiesAsString()); + } + + @Test + public void cookieStaysAfterMultipleRequestsEndToEnd() { + client.handshake(); + Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); + client.handshake(); + Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); + client.listFlights(Criteria.ALL); + Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); + } + + /** + * A server middleware component that injects SET_COOKIE_HEADER into the outgoing headers. + */ + static class SetCookieHeaderInjector implements FlightServerMiddleware { + private final Factory factory; + + public SetCookieHeaderInjector(Factory factory) { + this.factory = factory; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + if (!factory.receivedCookieHeader) { + outgoingHeaders.insert(SET_COOKIE_HEADER, "k=v"); + } + } + + @Override + public void onCallCompleted(CallStatus status) { + + } + + @Override + public void onCallErrored(Throwable err) { + + } + + static class Factory implements FlightServerMiddleware.Factory { + private boolean receivedCookieHeader = false; + + @Override + public SetCookieHeaderInjector onCallStarted(CallInfo info, CallHeaders incomingHeaders, + RequestContext context) { + receivedCookieHeader = null != incomingHeaders.get(COOKIE_HEADER); + return new SetCookieHeaderInjector(this); + } + } + } + + public static class ClientCookieMiddlewareTestFactory extends ClientCookieMiddleware.Factory { + + private ClientCookieMiddleware clientCookieMiddleware; + + @Override + public ClientCookieMiddleware onCallStarted(CallInfo info) { + this.clientCookieMiddleware = new ClientCookieMiddleware(this); + return this.clientCookieMiddleware; + } + } + + private void startServerAndClient() throws IOException { + final FlightProducer flightProducer = new NoOpFlightProducer() { + public void listFlights(CallContext context, Criteria criteria, + StreamListener listener) { + listener.onCompleted(); + } + }; + + this.server = FlightTestUtil.getStartedServer((location) -> FlightServer + .builder(allocator, location, flightProducer) + .middleware(FlightServerMiddleware.Key.of("test"), new SetCookieHeaderInjector.Factory()) + .build()); + + this.client = FlightClient.builder(allocator, server.getLocation()) + .intercept(testFactory) + .build(); + } +}