Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.flight;

import java.util.Collection;
import java.util.Set;
import java.util.stream.Collectors;

import com.google.common.base.Preconditions;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Multimap;

import io.grpc.Metadata;

/**
* An implementation of the Flight headers interface for headers.
*/
public class FlightCallHeaders implements CallHeaders {
private final Multimap<String, Object> keysAndValues;

public FlightCallHeaders() {
this.keysAndValues = ArrayListMultimap.create();
}

@Override
public String get(String key) {
final Collection<Object> values = this.keysAndValues.get(key);
if (values.isEmpty()) {
return null;
}

if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
return new String((byte[]) Iterables.get(values, 0));
}

return (String) Iterables.get(values, 0);
}

@Override
public byte[] getByte(String key) {
final Collection<Object> values = this.keysAndValues.get(key);
if (values.isEmpty()) {
return null;
}

if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
return (byte[]) Iterables.get(values, 0);
}

return ((String) Iterables.get(values, 0)).getBytes();
}

@Override
public Iterable<String> getAll(String key) {
if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
return this.keysAndValues.get(key).stream().map(o -> new String((byte[]) o)).collect(Collectors.toList());
}
return (Collection<String>) (Collection<?>) this.keysAndValues.get(key);
}

@Override
public Iterable<byte[]> getAllByte(String key) {
if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
return (Collection<byte[]>) (Collection<?>) this.keysAndValues.get(key);
}
return this.keysAndValues.get(key).stream().map(o -> ((String) o).getBytes()).collect(Collectors.toList());
}

@Override
public void insert(String key, String value) {
this.keysAndValues.put(key, value);
}

@Override
public void insert(String key, byte[] value) {
Preconditions.checkArgument(key.endsWith("-bin"), "Binary header is named %s. It must end with %s", key, "-bin");
Preconditions.checkArgument(key.length() > "-bin".length(), "empty key name");

this.keysAndValues.put(key, value);
}

@Override
public Set<String> keys() {
return this.keysAndValues.keySet();
}

@Override
public boolean containsKey(String key) {
return this.keysAndValues.containsKey(key);
}

public String toString() {
return this.keysAndValues.toString();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@ public interface FlightConstants {

String SERVICE = "arrow.flight.protocol.FlightService";

FlightServerMiddleware.Key<ServerHeaderMiddleware> HEADER_KEY =
FlightServerMiddleware.Key.of("org.apache.arrow.flight.ServerHeaderMiddleware");
}
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ public FlightServer build() {
this.middleware(FlightServerMiddleware.Key.of(Auth2Constants.AUTHORIZATION_HEADER),
new ServerCallHeaderAuthMiddleware.Factory(headerAuthenticator));
}

this.middleware(FlightConstants.HEADER_KEY, new ServerHeaderMiddleware.Factory());

final NettyServerBuilder builder;
switch (location.getUri().getScheme()) {
case LocationSchemes.GRPC_DOMAIN_SOCKET: {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.flight;

import io.grpc.Metadata;
import io.grpc.stub.AbstractStub;
import io.grpc.stub.MetadataUtils;

/**
* Method option for supplying headers to method calls.
*/
public class HeaderCallOption implements CallOptions.GrpcCallOption {
private final Metadata propertiesMetadata = new Metadata();

/**
* Header property constructor.
*
* @param headers the headers that should be sent across. If a header is a string, it should only be valid ASCII
* characters. Binary headers should end in "-bin".
*/
public HeaderCallOption(CallHeaders headers) {
for (String key : headers.keys()) {
if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
final Metadata.Key<byte[]> metaKey = Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER);
headers.getAllByte(key).forEach(v -> propertiesMetadata.put(metaKey, v));
} else {
final Metadata.Key<String> metaKey = Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER);
headers.getAll(key).forEach(v -> propertiesMetadata.put(metaKey, v));
}
}
}

@Override
public <T extends AbstractStub<T>> T wrapStub(T stub) {
return MetadataUtils.attachHeaders(stub, propertiesMetadata);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.flight;

/**
* Middleware that's used to extract and pass headers to the server during requests.
*/
public class ServerHeaderMiddleware implements FlightServerMiddleware {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename ServerHeaderMiddleware to ServerPropertyMiddleware? The current name is a bit too generic.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the name is appropriate given that this is intended to handle all headers for calls, which could be considered properties depending on the FlightServer implementation, but not necessarily.

/**
* Factory for accessing ServerHeaderMiddleware.
*/
public static class Factory implements FlightServerMiddleware.Factory<ServerHeaderMiddleware> {
/**
* Construct a factory for receiving call headers.
*/
public Factory() {
}

@Override
public ServerHeaderMiddleware onCallStarted(CallInfo callInfo, CallHeaders incomingHeaders,
RequestContext context) {
return new ServerHeaderMiddleware(incomingHeaders);
}
}

private final CallHeaders headers;

private ServerHeaderMiddleware(CallHeaders incomingHeaders) {
this.headers = incomingHeaders;
}

/**
* Retrieve the headers for this call.
*/
public CallHeaders headers() {
return headers;
}

@Override
public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) {
}

@Override
public void onCallCompleted(CallStatus status) {
}

@Override
public void onCallErrored(Throwable err) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
import org.junit.Ignore;
import org.junit.Test;

import io.grpc.Metadata;

public class TestCallOptions {

@Test
Expand Down Expand Up @@ -64,10 +66,63 @@ public void underTimeout() {
});
}

@Test
public void singleProperty() {
final FlightCallHeaders headers = new FlightCallHeaders();
headers.insert("key", "value");
testHeaders(headers);
}

@Test
public void multipleProperties() {
final FlightCallHeaders headers = new FlightCallHeaders();
headers.insert("key", "value");
headers.insert("key2", "value2");
testHeaders(headers);
}

@Test
public void binaryProperties() {
final FlightCallHeaders headers = new FlightCallHeaders();
headers.insert("key-bin", "value".getBytes());
headers.insert("key3-bin", "ëfßæ".getBytes());
testHeaders(headers);
}

@Test
public void mixedProperties() {
final FlightCallHeaders headers = new FlightCallHeaders();
headers.insert("key", "value");
headers.insert("key3-bin", "ëfßæ".getBytes());
testHeaders(headers);
}

private void testHeaders(CallHeaders headers) {
try (
BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
HeaderProducer producer = new HeaderProducer();
FlightServer s =
FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build());
FlightClient client = FlightClient.builder(a, s.getLocation()).build()) {
client.doAction(new Action(""), new HeaderCallOption(headers)).hasNext();

final CallHeaders incomingHeaders = producer.headers();
for (String key : headers.keys()) {
if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) {
Assert.assertArrayEquals(headers.getByte(key), incomingHeaders.getByte(key));
} else {
Assert.assertEquals(headers.get(key), incomingHeaders.get(key));
}
}
} catch (InterruptedException | IOException e) {
throw new RuntimeException(e);
}
}

void test(Consumer<FlightClient> testFn) {
try (
BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
Producer producer = new Producer(a);
Producer producer = new Producer();
FlightServer s =
FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build());
FlightClient client = FlightClient.builder(a, s.getLocation()).build()) {
Expand All @@ -77,12 +132,27 @@ void test(Consumer<FlightClient> testFn) {
}
}

static class Producer extends NoOpFlightProducer implements AutoCloseable {
static class HeaderProducer extends NoOpFlightProducer implements AutoCloseable {
CallHeaders headers;

private final BufferAllocator allocator;
@Override
public void close() {
}

public CallHeaders headers() {
return headers;
}

@Override
public void doAction(CallContext context, Action action, StreamListener<Result> listener) {
this.headers = context.getMiddleware(FlightConstants.HEADER_KEY).headers();
listener.onCompleted();
}
}

static class Producer extends NoOpFlightProducer implements AutoCloseable {

Producer(BufferAllocator allocator) {
this.allocator = allocator;
Producer() {
}

@Override
Expand Down