Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
package org.apache.arrow.flight.client;

import java.net.HttpCookie;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Consumer;
import java.util.stream.Collectors;

Expand All @@ -37,29 +37,31 @@
*/
public class ClientCookieMiddleware implements FlightClientMiddleware {
private static final String SET_COOKIE_HEADER = "Set-Cookie";
private static final String SET_COOKIE2_HEADER = "Set-Cookie2";
private static final String COOKIE_HEADER = "Cookie";
private final Factory factory;

// Use a map to track the most recent version of a cookie from the server.
// Note that cookie names are case-sensitive (but header names aren't).
private Map<String, HttpCookie> cookies = new HashMap<>();

public ClientCookieMiddleware() {
@VisibleForTesting
ClientCookieMiddleware(Factory factory) {
this.factory = factory;
}

/**
* Factory used within FlightClient.
*/
public static class Factory implements FlightClientMiddleware.Factory {
// Use a map to track the most recent version of a cookie from the server.
// Note that cookie names are case-sensitive (but header names aren't).
private ConcurrentMap<String, HttpCookie> cookies = new ConcurrentHashMap<>();

@Override
public ClientCookieMiddleware onCallStarted(CallInfo info) {
return new ClientCookieMiddleware();
return new ClientCookieMiddleware(this);
}
}

@Override
public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
final String cookieValue = calculateCookieString();
final String cookieValue = getValidCookiesAsString();
if (!cookieValue.isEmpty()) {
outgoingHeaders.insert(COOKIE_HEADER, cookieValue);
}
Expand All @@ -74,25 +76,30 @@ public void onHeadersReceived(CallHeaders incomingHeaders) {
// to signal that the client should stop using the cookie immediately.
final Consumer<String> handleSetCookieHeader = (headerValue) -> {
final List<HttpCookie> parsedCookies = HttpCookie.parse(headerValue);
parsedCookies.forEach(parsedCookie -> cookies.put(parsedCookie.getName(), parsedCookie));
parsedCookies.forEach(parsedCookie -> factory.cookies.put(parsedCookie.getName(), parsedCookie));
};
incomingHeaders.getAll(SET_COOKIE_HEADER).forEach(handleSetCookieHeader);
incomingHeaders.getAll(SET_COOKIE2_HEADER).forEach(handleSetCookieHeader);
final Iterable<String> setCookieHeaders = incomingHeaders.getAll(SET_COOKIE_HEADER);
if (setCookieHeaders != null) {
setCookieHeaders.forEach(handleSetCookieHeader);
}
}

@Override
public void onCallCompleted(CallStatus status) {

}

/**
* Discards expired cookies and returns the valid cookies as a String delimited by ';'.
*/
@VisibleForTesting
String calculateCookieString() {
String getValidCookiesAsString() {
// Discard expired cookies.
cookies.entrySet().removeIf(cookieEntry -> cookieEntry.getValue().hasExpired());
factory.cookies.entrySet().removeIf(cookieEntry -> cookieEntry.getValue().hasExpired());

// Cookie header value format:
// <cookie-name1>=<cookie-value1>[; <cookie-name2>=<cookie-value2; ...]
return cookies.entrySet().stream()
// [<cookie-name1>=<cookie-value1>; <cookie-name2>=<cookie-value2; ...]
return factory.cookies.entrySet().stream()
.map(cookie -> cookie.getValue().toString())
.collect(Collectors.joining("; "));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,27 @@

package org.apache.arrow.flight.client;

import java.io.IOException;

import org.apache.arrow.flight.CallHeaders;
import org.apache.arrow.flight.CallInfo;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.ErrorFlightMetadata;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightServerMiddleware;
import org.apache.arrow.flight.FlightTestUtil;
import org.apache.arrow.flight.NoOpFlightProducer;
import org.apache.arrow.flight.RequestContext;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;

Expand All @@ -29,37 +46,51 @@
*/
public class TestCookieHandling {
private static final String SET_COOKIE_HEADER = "Set-Cookie";
private static final String SET_COOKIE2_HEADER = "Set-Cookie2";

private ClientCookieMiddleware cookieMiddleware = new ClientCookieMiddleware();
private static final String COOKIE_HEADER = "Cookie";
private BufferAllocator allocator;
private FlightServer server;
private FlightClient client;

private ClientCookieMiddlewareTestFactory testFactory = new ClientCookieMiddlewareTestFactory();
private ClientCookieMiddleware cookieMiddleware = new ClientCookieMiddleware(testFactory);

@Before
public void setup() throws Exception {
allocator = new RootAllocator(Long.MAX_VALUE);
startServerAndClient();
}

@After
public void cleanup() {
cookieMiddleware = new ClientCookieMiddleware();
public void cleanup() throws Exception {
cookieMiddleware = new ClientCookieMiddleware(testFactory);
AutoCloseables.close(client, server, allocator);
client = null;
server = null;
allocator = null;
}

@Test
public void basicCookie() {
CallHeaders headersToSend = new ErrorFlightMetadata();
headersToSend.insert(SET_COOKIE_HEADER, "k=v");
cookieMiddleware.onHeadersReceived(headersToSend);
Assert.assertEquals("k=v", cookieMiddleware.calculateCookieString());
Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());
}

@Test
public void cookieStaysAfterMultipleRequests() {
CallHeaders headersToSend = new ErrorFlightMetadata();
headersToSend.insert(SET_COOKIE_HEADER, "k=v");
cookieMiddleware.onHeadersReceived(headersToSend);
Assert.assertEquals("k=v", cookieMiddleware.calculateCookieString());
Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());

headersToSend = new ErrorFlightMetadata();
cookieMiddleware.onHeadersReceived(headersToSend);
Assert.assertEquals("k=v", cookieMiddleware.calculateCookieString());
Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());

headersToSend = new ErrorFlightMetadata();
cookieMiddleware.onHeadersReceived(headersToSend);
Assert.assertEquals("k=v", cookieMiddleware.calculateCookieString());
Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());
}

@Test
Expand All @@ -68,19 +99,19 @@ public void cookieAutoExpires() {
headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2");
cookieMiddleware.onHeadersReceived(headersToSend);
// Note: using max-age changes cookie version from 0->1, which quotes values.
Assert.assertEquals("k=\"v\"", cookieMiddleware.calculateCookieString());
Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());

headersToSend = new ErrorFlightMetadata();
cookieMiddleware.onHeadersReceived(headersToSend);
Assert.assertEquals("k=\"v\"", cookieMiddleware.calculateCookieString());
Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());

try {
Thread.sleep(5000);
} catch (InterruptedException ignored) {
}

// Verify that the k cookie was discarded because it expired.
Assert.assertTrue(cookieMiddleware.calculateCookieString().isEmpty());
Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty());
}

@Test
Expand All @@ -89,14 +120,14 @@ public void cookieExplicitlyExpires() {
headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2");
cookieMiddleware.onHeadersReceived(headersToSend);
// Note: using max-age changes cookie version from 0->1, which quotes values.
Assert.assertEquals("k=\"v\"", cookieMiddleware.calculateCookieString());
Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());

headersToSend = new ErrorFlightMetadata();
headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=-2");
cookieMiddleware.onHeadersReceived(headersToSend);

// Verify that the k cookie was discarded because the server told the client it is expired.
Assert.assertTrue(cookieMiddleware.calculateCookieString().isEmpty());
Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty());
}

@Ignore
Expand All @@ -106,7 +137,7 @@ public void cookieExplicitlyExpiresWithMaxAgeMinusOne() {
headersToSend.insert(SET_COOKIE_HEADER, "k=v; Max-Age=2");
cookieMiddleware.onHeadersReceived(headersToSend);
// Note: using max-age changes cookie version from 0->1, which quotes values.
Assert.assertEquals("k=\"v\"", cookieMiddleware.calculateCookieString());
Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString());

headersToSend = new ErrorFlightMetadata();

Expand All @@ -116,20 +147,20 @@ public void cookieExplicitlyExpiresWithMaxAgeMinusOne() {
cookieMiddleware.onHeadersReceived(headersToSend);

// Verify that the k cookie was discarded because the server told the client it is expired.
Assert.assertTrue(cookieMiddleware.calculateCookieString().isEmpty());
Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty());
}

@Test
public void changeCookieValue() {
CallHeaders headersToSend = new ErrorFlightMetadata();
headersToSend.insert(SET_COOKIE_HEADER, "k=v");
cookieMiddleware.onHeadersReceived(headersToSend);
Assert.assertEquals("k=v", cookieMiddleware.calculateCookieString());
Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString());

headersToSend = new ErrorFlightMetadata();
headersToSend.insert(SET_COOKIE_HEADER, "k=v2");
cookieMiddleware.onHeadersReceived(headersToSend);
Assert.assertEquals("k=v2", cookieMiddleware.calculateCookieString());
Assert.assertEquals("k=v2", cookieMiddleware.getValidCookiesAsString());
}

@Test
Expand All @@ -138,27 +169,84 @@ public void multipleCookiesWithSetCookie() {
headersToSend.insert(SET_COOKIE_HEADER, "firstKey=firstVal");
headersToSend.insert(SET_COOKIE_HEADER, "secondKey=secondVal");
cookieMiddleware.onHeadersReceived(headersToSend);
Assert.assertEquals("firstKey=firstVal; secondKey=secondVal", cookieMiddleware.calculateCookieString());
Assert.assertEquals("firstKey=firstVal; secondKey=secondVal", cookieMiddleware.getValidCookiesAsString());
}

@Test
public void basicCookiesWithSetCookie2() {
CallHeaders headersToSend = new ErrorFlightMetadata();
headersToSend.insert(SET_COOKIE2_HEADER, "firstKey=firstVal");
cookieMiddleware.onHeadersReceived(headersToSend);
Assert.assertEquals("firstKey=firstVal", cookieMiddleware.calculateCookieString());
public void cookieStaysAfterMultipleRequestsEndToEnd() {
client.handshake();
Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString());
client.handshake();
Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString());
client.listFlights(Criteria.ALL);
Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString());
}

@Ignore
@Test
public void multipleCookiesWithSetCookie2() {
// There seems to be a JDK bug with HttpCookie.parse() with multiple cookies
// in a Set-Cookie2 header. This is odd, because that method explictly returns a list
// of cookies because of Set-Cookie2.
// Set-Cookie2 itself is deprecated.
CallHeaders headersToSend = new ErrorFlightMetadata();
headersToSend.insert(SET_COOKIE2_HEADER, "firstKey=firstVal, secondKey=secondVal");
cookieMiddleware.onHeadersReceived(headersToSend);
Assert.assertEquals("firstKey=firstVal; secondKey=secondVal", cookieMiddleware.calculateCookieString());
/**
* A server middleware component that injects SET_COOKIE_HEADER into the outgoing headers.
*/
static class SetCookieHeaderInjector implements FlightServerMiddleware {
private final Factory factory;

public SetCookieHeaderInjector(Factory factory) {
this.factory = factory;
}

@Override
public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
if (!factory.receivedCookieHeader) {
outgoingHeaders.insert(SET_COOKIE_HEADER, "k=v");
}
}

@Override
public void onCallCompleted(CallStatus status) {

}

@Override
public void onCallErrored(Throwable err) {

}

static class Factory implements FlightServerMiddleware.Factory<SetCookieHeaderInjector> {
private boolean receivedCookieHeader = false;

@Override
public SetCookieHeaderInjector onCallStarted(CallInfo info, CallHeaders incomingHeaders,
RequestContext context) {
receivedCookieHeader = null != incomingHeaders.get(COOKIE_HEADER);
return new SetCookieHeaderInjector(this);
}
}
}

public static class ClientCookieMiddlewareTestFactory extends ClientCookieMiddleware.Factory {

private ClientCookieMiddleware clientCookieMiddleware;

@Override
public ClientCookieMiddleware onCallStarted(CallInfo info) {
this.clientCookieMiddleware = new ClientCookieMiddleware(this);
return this.clientCookieMiddleware;
}
}

private void startServerAndClient() throws IOException {
final FlightProducer flightProducer = new NoOpFlightProducer() {
public void listFlights(CallContext context, Criteria criteria,
StreamListener<FlightInfo> listener) {
listener.onCompleted();
}
};

this.server = FlightTestUtil.getStartedServer((location) -> FlightServer
.builder(allocator, location, flightProducer)
.middleware(FlightServerMiddleware.Key.of("test"), new SetCookieHeaderInjector.Factory())
.build());

this.client = FlightClient.builder(allocator, server.getLocation())
.intercept(testFactory)
.build();
}
}