Skip to content

Commit 4846efd

Browse files
committed
GH-81: [Flight] Expose gRPC in Flight client builder
Fixes #81.
1 parent d304da5 commit 4846efd

File tree

3 files changed

+267
-137
lines changed

3 files changed

+267
-137
lines changed

flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java

Lines changed: 21 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,13 @@
2323
import io.grpc.ManagedChannel;
2424
import io.grpc.MethodDescriptor;
2525
import io.grpc.StatusRuntimeException;
26-
import io.grpc.netty.GrpcSslContexts;
2726
import io.grpc.netty.NettyChannelBuilder;
2827
import io.grpc.stub.ClientCallStreamObserver;
2928
import io.grpc.stub.ClientCalls;
3029
import io.grpc.stub.ClientResponseObserver;
3130
import io.grpc.stub.StreamObserver;
32-
import io.netty.channel.EventLoopGroup;
33-
import io.netty.channel.ServerChannel;
34-
import io.netty.handler.ssl.SslContextBuilder;
35-
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
3631
import java.io.IOException;
3732
import java.io.InputStream;
38-
import java.lang.reflect.InvocationTargetException;
3933
import java.net.URISyntaxException;
4034
import java.nio.ByteBuffer;
4135
import java.util.ArrayList;
@@ -45,7 +39,6 @@
4539
import java.util.concurrent.ExecutionException;
4640
import java.util.concurrent.TimeUnit;
4741
import java.util.function.BooleanSupplier;
48-
import javax.net.ssl.SSLException;
4942
import org.apache.arrow.flight.FlightProducer.StreamListener;
5043
import org.apache.arrow.flight.auth.BasicClientAuthHandler;
5144
import org.apache.arrow.flight.auth.ClientAuthHandler;
@@ -57,6 +50,7 @@
5750
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
5851
import org.apache.arrow.flight.grpc.ClientInterceptorAdapter;
5952
import org.apache.arrow.flight.grpc.CredentialCallOption;
53+
import org.apache.arrow.flight.grpc.NettyClientBuilder;
6054
import org.apache.arrow.flight.grpc.StatusUtils;
6155
import org.apache.arrow.flight.impl.Flight;
6256
import org.apache.arrow.flight.impl.Flight.Empty;
@@ -73,12 +67,6 @@
7367
public class FlightClient implements AutoCloseable {
7468
private static final int PENDING_REQUESTS = 5;
7569

76-
/**
77-
* The maximum number of trace events to keep on the gRPC Channel. This value disables channel
78-
* tracing.
79-
*/
80-
private static final int MAX_CHANNEL_TRACE_EVENTS = 0;
81-
8270
private final BufferAllocator allocator;
8371
private final ManagedChannel channel;
8472

@@ -97,11 +85,12 @@ public class FlightClient implements AutoCloseable {
9785
List<FlightClientMiddleware.Factory> middleware) {
9886
this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE);
9987
this.channel = channel;
100-
this.middleware = middleware;
88+
// We need a mutable copy (shared between this class and ClientInterceptorAdapter)
89+
this.middleware = new ArrayList<>(middleware);
10190

10291
final ClientInterceptor[] interceptors;
10392
interceptors =
104-
new ClientInterceptor[] {authInterceptor, new ClientInterceptorAdapter(middleware)};
93+
new ClientInterceptor[] {authInterceptor, new ClientInterceptorAdapter(this.middleware)};
10594

10695
// Create a channel with interceptors pre-applied for DoGet and DoPut
10796
Channel interceptedChannel = ClientInterceptors.intercept(channel, interceptors);
@@ -772,176 +761,71 @@ public static Builder builder(BufferAllocator allocator, Location location) {
772761

773762
/** A builder for Flight clients. */
774763
public static final class Builder {
775-
private BufferAllocator allocator;
776-
private Location location;
777-
private boolean forceTls = false;
778-
private int maxInboundMessageSize = FlightServer.MAX_GRPC_MESSAGE_SIZE;
779-
private InputStream trustedCertificates = null;
780-
private InputStream clientCertificate = null;
781-
private InputStream clientKey = null;
782-
private String overrideHostname = null;
783-
private List<FlightClientMiddleware.Factory> middleware = new ArrayList<>();
784-
private boolean verifyServer = true;
785-
786-
private Builder() {}
764+
private final NettyClientBuilder builder;
765+
766+
private Builder() {
767+
this.builder = new NettyClientBuilder();
768+
}
787769

788770
private Builder(BufferAllocator allocator, Location location) {
789-
this.allocator = Preconditions.checkNotNull(allocator);
790-
this.location = Preconditions.checkNotNull(location);
771+
this.builder = new NettyClientBuilder(allocator, location);
791772
}
792773

793774
/** Force the client to connect over TLS. */
794775
public Builder useTls() {
795-
this.forceTls = true;
776+
builder.useTls();
796777
return this;
797778
}
798779

799780
/** Override the hostname checked for TLS. Use with caution in production. */
800781
public Builder overrideHostname(final String hostname) {
801-
this.overrideHostname = hostname;
782+
builder.overrideHostname(hostname);
802783
return this;
803784
}
804785

805786
/** Set the maximum inbound message size. */
806787
public Builder maxInboundMessageSize(int maxSize) {
807-
Preconditions.checkArgument(maxSize > 0);
808-
this.maxInboundMessageSize = maxSize;
788+
builder.maxInboundMessageSize(maxSize);
809789
return this;
810790
}
811791

812792
/** Set the trusted TLS certificates. */
813793
public Builder trustedCertificates(final InputStream stream) {
814-
this.trustedCertificates = Preconditions.checkNotNull(stream);
794+
builder.trustedCertificates(stream);
815795
return this;
816796
}
817797

818798
/** Set the trusted TLS certificates. */
819799
public Builder clientCertificate(
820800
final InputStream clientCertificate, final InputStream clientKey) {
821-
Preconditions.checkNotNull(clientKey);
822-
this.clientCertificate = Preconditions.checkNotNull(clientCertificate);
823-
this.clientKey = Preconditions.checkNotNull(clientKey);
801+
builder.clientCertificate(clientCertificate, clientKey);
824802
return this;
825803
}
826804

827805
public Builder allocator(BufferAllocator allocator) {
828-
this.allocator = Preconditions.checkNotNull(allocator);
806+
builder.allocator(allocator);
829807
return this;
830808
}
831809

832810
public Builder location(Location location) {
833-
this.location = Preconditions.checkNotNull(location);
811+
builder.location(location);
834812
return this;
835813
}
836814

837815
public Builder intercept(FlightClientMiddleware.Factory factory) {
838-
middleware.add(factory);
816+
builder.intercept(factory);
839817
return this;
840818
}
841819

842820
public Builder verifyServer(boolean verifyServer) {
843-
this.verifyServer = verifyServer;
821+
builder.verifyServer(verifyServer);
844822
return this;
845823
}
846824

847825
/** Create the client from this builder. */
848826
public FlightClient build() {
849-
final NettyChannelBuilder builder;
850-
851-
switch (location.getUri().getScheme()) {
852-
case LocationSchemes.GRPC:
853-
case LocationSchemes.GRPC_INSECURE:
854-
case LocationSchemes.GRPC_TLS:
855-
{
856-
builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
857-
break;
858-
}
859-
case LocationSchemes.GRPC_DOMAIN_SOCKET:
860-
{
861-
// The implementation is platform-specific, so we have to find the classes at runtime
862-
builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
863-
try {
864-
try {
865-
// Linux
866-
builder.channelType(
867-
Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel")
868-
.asSubclass(ServerChannel.class));
869-
final EventLoopGroup elg =
870-
Class.forName("io.netty.channel.epoll.EpollEventLoopGroup")
871-
.asSubclass(EventLoopGroup.class)
872-
.getDeclaredConstructor()
873-
.newInstance();
874-
builder.eventLoopGroup(elg);
875-
} catch (ClassNotFoundException e) {
876-
// BSD
877-
builder.channelType(
878-
Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel")
879-
.asSubclass(ServerChannel.class));
880-
final EventLoopGroup elg =
881-
Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup")
882-
.asSubclass(EventLoopGroup.class)
883-
.getDeclaredConstructor()
884-
.newInstance();
885-
builder.eventLoopGroup(elg);
886-
}
887-
} catch (ClassNotFoundException
888-
| InstantiationException
889-
| IllegalAccessException
890-
| NoSuchMethodException
891-
| InvocationTargetException e) {
892-
throw new UnsupportedOperationException(
893-
"Could not find suitable Netty native transport implementation for domain socket address.");
894-
}
895-
break;
896-
}
897-
default:
898-
throw new IllegalArgumentException(
899-
"Scheme is not supported: " + location.getUri().getScheme());
900-
}
901-
902-
if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) {
903-
builder.useTransportSecurity();
904-
905-
final boolean hasTrustedCerts = this.trustedCertificates != null;
906-
final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null;
907-
if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) {
908-
throw new IllegalArgumentException(
909-
"FlightClient has been configured to disable server verification, "
910-
+ "but certificate options have been specified.");
911-
}
912-
913-
final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();
914-
915-
if (!this.verifyServer) {
916-
sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
917-
} else if (this.trustedCertificates != null
918-
|| this.clientCertificate != null
919-
|| this.clientKey != null) {
920-
if (this.trustedCertificates != null) {
921-
sslContextBuilder.trustManager(this.trustedCertificates);
922-
}
923-
if (this.clientCertificate != null && this.clientKey != null) {
924-
sslContextBuilder.keyManager(this.clientCertificate, this.clientKey);
925-
}
926-
}
927-
try {
928-
builder.sslContext(sslContextBuilder.build());
929-
} catch (SSLException e) {
930-
throw new RuntimeException(e);
931-
}
932-
933-
if (this.overrideHostname != null) {
934-
builder.overrideAuthority(this.overrideHostname);
935-
}
936-
} else {
937-
builder.usePlaintext();
938-
}
939-
940-
builder
941-
.maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
942-
.maxInboundMessageSize(maxInboundMessageSize)
943-
.maxInboundMetadataSize(maxInboundMessageSize);
944-
return new FlightClient(allocator, builder.build(), middleware);
827+
final NettyChannelBuilder channelBuilder = builder.build();
828+
return new FlightClient(builder.allocator(), channelBuilder.build(), builder.middleware());
945829
}
946830
}
947831

flight/flight-core/src/main/java/org/apache/arrow/flight/FlightGrpcUtils.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import io.grpc.ManagedChannel;
2424
import io.grpc.MethodDescriptor;
2525
import java.util.Collections;
26+
import java.util.List;
2627
import java.util.concurrent.ExecutorService;
2728
import java.util.concurrent.TimeUnit;
2829
import org.apache.arrow.flight.auth.ServerAuthHandler;
@@ -151,6 +152,19 @@ public static FlightClient createFlightClient(
151152
return new FlightClient(incomingAllocator, channel, Collections.emptyList());
152153
}
153154

155+
/**
156+
* Creates a Flight client.
157+
*
158+
* @param incomingAllocator Memory allocator
159+
* @param channel provides a connection to a gRPC server.
160+
*/
161+
public static FlightClient createFlightClient(
162+
BufferAllocator incomingAllocator,
163+
ManagedChannel channel,
164+
List<FlightClientMiddleware.Factory> middleware) {
165+
return new FlightClient(incomingAllocator, channel, middleware);
166+
}
167+
154168
/**
155169
* Creates a Flight client.
156170
*

0 commit comments

Comments
 (0)