diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java new file mode 100644 index 00000000000..dd26d190872 --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightCallHeaders.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.Collection; +import java.util.Set; +import java.util.stream.Collectors; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.Iterables; +import com.google.common.collect.Multimap; + +import io.grpc.Metadata; + +/** + * An implementation of the Flight headers interface for headers. + */ +public class FlightCallHeaders implements CallHeaders { + private final Multimap keysAndValues; + + public FlightCallHeaders() { + this.keysAndValues = ArrayListMultimap.create(); + } + + @Override + public String get(String key) { + final Collection values = this.keysAndValues.get(key); + if (values.isEmpty()) { + return null; + } + + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return new String((byte[]) Iterables.get(values, 0)); + } + + return (String) Iterables.get(values, 0); + } + + @Override + public byte[] getByte(String key) { + final Collection values = this.keysAndValues.get(key); + if (values.isEmpty()) { + return null; + } + + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return (byte[]) Iterables.get(values, 0); + } + + return ((String) Iterables.get(values, 0)).getBytes(); + } + + @Override + public Iterable getAll(String key) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return this.keysAndValues.get(key).stream().map(o -> new String((byte[]) o)).collect(Collectors.toList()); + } + return (Collection) (Collection) this.keysAndValues.get(key); + } + + @Override + public Iterable getAllByte(String key) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return (Collection) (Collection) this.keysAndValues.get(key); + } + return this.keysAndValues.get(key).stream().map(o -> ((String) o).getBytes()).collect(Collectors.toList()); + } + + @Override + public void insert(String key, String value) { + this.keysAndValues.put(key, value); + } + + @Override + public void insert(String key, byte[] value) { + Preconditions.checkArgument(key.endsWith("-bin"), "Binary header is named %s. It must end with %s", key, "-bin"); + Preconditions.checkArgument(key.length() > "-bin".length(), "empty key name"); + + this.keysAndValues.put(key, value); + } + + @Override + public Set keys() { + return this.keysAndValues.keySet(); + } + + @Override + public boolean containsKey(String key) { + return this.keysAndValues.containsKey(key); + } + + public String toString() { + return this.keysAndValues.toString(); + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java index 2ea8cc7e344..2d039c9d24e 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightConstants.java @@ -24,4 +24,6 @@ public interface FlightConstants { String SERVICE = "arrow.flight.protocol.FlightService"; + FlightServerMiddleware.Key HEADER_KEY = + FlightServerMiddleware.Key.of("org.apache.arrow.flight.ServerHeaderMiddleware"); } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java index ee62ee5eb83..d59480bfb0a 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightServer.java @@ -196,6 +196,9 @@ public FlightServer build() { this.middleware(FlightServerMiddleware.Key.of(Auth2Constants.AUTHORIZATION_HEADER), new ServerCallHeaderAuthMiddleware.Factory(headerAuthenticator)); } + + this.middleware(FlightConstants.HEADER_KEY, new ServerHeaderMiddleware.Factory()); + final NettyServerBuilder builder; switch (location.getUri().getScheme()) { case LocationSchemes.GRPC_DOMAIN_SOCKET: { diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java new file mode 100644 index 00000000000..e2fad1a402d --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/HeaderCallOption.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import io.grpc.Metadata; +import io.grpc.stub.AbstractStub; +import io.grpc.stub.MetadataUtils; + +/** + * Method option for supplying headers to method calls. + */ +public class HeaderCallOption implements CallOptions.GrpcCallOption { + private final Metadata propertiesMetadata = new Metadata(); + + /** + * Header property constructor. + * + * @param headers the headers that should be sent across. If a header is a string, it should only be valid ASCII + * characters. Binary headers should end in "-bin". + */ + public HeaderCallOption(CallHeaders headers) { + for (String key : headers.keys()) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + final Metadata.Key metaKey = Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER); + headers.getAllByte(key).forEach(v -> propertiesMetadata.put(metaKey, v)); + } else { + final Metadata.Key metaKey = Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER); + headers.getAll(key).forEach(v -> propertiesMetadata.put(metaKey, v)); + } + } + } + + @Override + public > T wrapStub(T stub) { + return MetadataUtils.attachHeaders(stub, propertiesMetadata); + } +} diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerHeaderMiddleware.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerHeaderMiddleware.java new file mode 100644 index 00000000000..527c3128c65 --- /dev/null +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/ServerHeaderMiddleware.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +/** + * Middleware that's used to extract and pass headers to the server during requests. + */ +public class ServerHeaderMiddleware implements FlightServerMiddleware { + /** + * Factory for accessing ServerHeaderMiddleware. + */ + public static class Factory implements FlightServerMiddleware.Factory { + /** + * Construct a factory for receiving call headers. + */ + public Factory() { + } + + @Override + public ServerHeaderMiddleware onCallStarted(CallInfo callInfo, CallHeaders incomingHeaders, + RequestContext context) { + return new ServerHeaderMiddleware(incomingHeaders); + } + } + + private final CallHeaders headers; + + private ServerHeaderMiddleware(CallHeaders incomingHeaders) { + this.headers = incomingHeaders; + } + + /** + * Retrieve the headers for this call. + */ + public CallHeaders headers() { + return headers; + } + + @Override + public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { + } + + @Override + public void onCallCompleted(CallStatus status) { + } + + @Override + public void onCallErrored(Throwable err) { + } +} diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java index 3acb9473006..45e3e496092 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java @@ -30,6 +30,8 @@ import org.junit.Ignore; import org.junit.Test; +import io.grpc.Metadata; + public class TestCallOptions { @Test @@ -64,10 +66,63 @@ public void underTimeout() { }); } + @Test + public void singleProperty() { + final FlightCallHeaders headers = new FlightCallHeaders(); + headers.insert("key", "value"); + testHeaders(headers); + } + + @Test + public void multipleProperties() { + final FlightCallHeaders headers = new FlightCallHeaders(); + headers.insert("key", "value"); + headers.insert("key2", "value2"); + testHeaders(headers); + } + + @Test + public void binaryProperties() { + final FlightCallHeaders headers = new FlightCallHeaders(); + headers.insert("key-bin", "value".getBytes()); + headers.insert("key3-bin", "ëfßæ".getBytes()); + testHeaders(headers); + } + + @Test + public void mixedProperties() { + final FlightCallHeaders headers = new FlightCallHeaders(); + headers.insert("key", "value"); + headers.insert("key3-bin", "ëfßæ".getBytes()); + testHeaders(headers); + } + + private void testHeaders(CallHeaders headers) { + try ( + BufferAllocator a = new RootAllocator(Long.MAX_VALUE); + HeaderProducer producer = new HeaderProducer(); + FlightServer s = + FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build()); + FlightClient client = FlightClient.builder(a, s.getLocation()).build()) { + client.doAction(new Action(""), new HeaderCallOption(headers)).hasNext(); + + final CallHeaders incomingHeaders = producer.headers(); + for (String key : headers.keys()) { + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + Assert.assertArrayEquals(headers.getByte(key), incomingHeaders.getByte(key)); + } else { + Assert.assertEquals(headers.get(key), incomingHeaders.get(key)); + } + } + } catch (InterruptedException | IOException e) { + throw new RuntimeException(e); + } + } + void test(Consumer testFn) { try ( BufferAllocator a = new RootAllocator(Long.MAX_VALUE); - Producer producer = new Producer(a); + Producer producer = new Producer(); FlightServer s = FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build()); FlightClient client = FlightClient.builder(a, s.getLocation()).build()) { @@ -77,12 +132,27 @@ void test(Consumer testFn) { } } - static class Producer extends NoOpFlightProducer implements AutoCloseable { + static class HeaderProducer extends NoOpFlightProducer implements AutoCloseable { + CallHeaders headers; - private final BufferAllocator allocator; + @Override + public void close() { + } + + public CallHeaders headers() { + return headers; + } + + @Override + public void doAction(CallContext context, Action action, StreamListener listener) { + this.headers = context.getMiddleware(FlightConstants.HEADER_KEY).headers(); + listener.onCompleted(); + } + } + + static class Producer extends NoOpFlightProducer implements AutoCloseable { - Producer(BufferAllocator allocator) { - this.allocator = allocator; + Producer() { } @Override