Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,13 @@
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.NettyChannelBuilder;
import io.grpc.stub.ClientCallStreamObserver;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.ClientResponseObserver;
import io.grpc.stub.StreamObserver;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ServerChannel;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.InvocationTargetException;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
Expand All @@ -45,7 +39,6 @@
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.function.BooleanSupplier;
import javax.net.ssl.SSLException;
import org.apache.arrow.flight.FlightProducer.StreamListener;
import org.apache.arrow.flight.auth.BasicClientAuthHandler;
import org.apache.arrow.flight.auth.ClientAuthHandler;
Expand All @@ -57,6 +50,7 @@
import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
import org.apache.arrow.flight.grpc.ClientInterceptorAdapter;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.flight.grpc.NettyClientBuilder;
import org.apache.arrow.flight.grpc.StatusUtils;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.impl.Flight.Empty;
Expand All @@ -73,12 +67,6 @@
public class FlightClient implements AutoCloseable {
private static final int PENDING_REQUESTS = 5;

/**
* The maximum number of trace events to keep on the gRPC Channel. This value disables channel
* tracing.
*/
private static final int MAX_CHANNEL_TRACE_EVENTS = 0;

private final BufferAllocator allocator;
private final ManagedChannel channel;

Expand All @@ -97,11 +85,12 @@ public class FlightClient implements AutoCloseable {
List<FlightClientMiddleware.Factory> middleware) {
this.allocator = incomingAllocator.newChildAllocator("flight-client", 0, Long.MAX_VALUE);
this.channel = channel;
this.middleware = middleware;
// We need a mutable copy (shared between this class and ClientInterceptorAdapter)
this.middleware = new ArrayList<>(middleware);

final ClientInterceptor[] interceptors;
interceptors =
new ClientInterceptor[] {authInterceptor, new ClientInterceptorAdapter(middleware)};
new ClientInterceptor[] {authInterceptor, new ClientInterceptorAdapter(this.middleware)};

// Create a channel with interceptors pre-applied for DoGet and DoPut
Channel interceptedChannel = ClientInterceptors.intercept(channel, interceptors);
Expand Down Expand Up @@ -772,176 +761,71 @@ public static Builder builder(BufferAllocator allocator, Location location) {

/** A builder for Flight clients. */
public static final class Builder {
private BufferAllocator allocator;
private Location location;
private boolean forceTls = false;
private int maxInboundMessageSize = FlightServer.MAX_GRPC_MESSAGE_SIZE;
private InputStream trustedCertificates = null;
private InputStream clientCertificate = null;
private InputStream clientKey = null;
private String overrideHostname = null;
private List<FlightClientMiddleware.Factory> middleware = new ArrayList<>();
private boolean verifyServer = true;

private Builder() {}
private final NettyClientBuilder builder;

private Builder() {
this.builder = new NettyClientBuilder();
}

private Builder(BufferAllocator allocator, Location location) {
this.allocator = Preconditions.checkNotNull(allocator);
this.location = Preconditions.checkNotNull(location);
this.builder = new NettyClientBuilder(allocator, location);
}

/** Force the client to connect over TLS. */
public Builder useTls() {
this.forceTls = true;
builder.useTls();
return this;
}

/** Override the hostname checked for TLS. Use with caution in production. */
public Builder overrideHostname(final String hostname) {
this.overrideHostname = hostname;
builder.overrideHostname(hostname);
return this;
}

/** Set the maximum inbound message size. */
public Builder maxInboundMessageSize(int maxSize) {
Preconditions.checkArgument(maxSize > 0);
this.maxInboundMessageSize = maxSize;
builder.maxInboundMessageSize(maxSize);
return this;
}

/** Set the trusted TLS certificates. */
public Builder trustedCertificates(final InputStream stream) {
this.trustedCertificates = Preconditions.checkNotNull(stream);
builder.trustedCertificates(stream);
return this;
}

/** Set the trusted TLS certificates. */
public Builder clientCertificate(
final InputStream clientCertificate, final InputStream clientKey) {
Preconditions.checkNotNull(clientKey);
this.clientCertificate = Preconditions.checkNotNull(clientCertificate);
this.clientKey = Preconditions.checkNotNull(clientKey);
builder.clientCertificate(clientCertificate, clientKey);
return this;
}

public Builder allocator(BufferAllocator allocator) {
this.allocator = Preconditions.checkNotNull(allocator);
builder.allocator(allocator);
return this;
}

public Builder location(Location location) {
this.location = Preconditions.checkNotNull(location);
builder.location(location);
return this;
}

public Builder intercept(FlightClientMiddleware.Factory factory) {
middleware.add(factory);
builder.intercept(factory);
return this;
}

public Builder verifyServer(boolean verifyServer) {
this.verifyServer = verifyServer;
builder.verifyServer(verifyServer);
return this;
}

/** Create the client from this builder. */
public FlightClient build() {
final NettyChannelBuilder builder;

switch (location.getUri().getScheme()) {
case LocationSchemes.GRPC:
case LocationSchemes.GRPC_INSECURE:
case LocationSchemes.GRPC_TLS:
{
builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
break;
}
case LocationSchemes.GRPC_DOMAIN_SOCKET:
{
// The implementation is platform-specific, so we have to find the classes at runtime
builder = NettyChannelBuilder.forAddress(location.toSocketAddress());
try {
try {
// Linux
builder.channelType(
Class.forName("io.netty.channel.epoll.EpollDomainSocketChannel")
.asSubclass(ServerChannel.class));
final EventLoopGroup elg =
Class.forName("io.netty.channel.epoll.EpollEventLoopGroup")
.asSubclass(EventLoopGroup.class)
.getDeclaredConstructor()
.newInstance();
builder.eventLoopGroup(elg);
} catch (ClassNotFoundException e) {
// BSD
builder.channelType(
Class.forName("io.netty.channel.kqueue.KQueueDomainSocketChannel")
.asSubclass(ServerChannel.class));
final EventLoopGroup elg =
Class.forName("io.netty.channel.kqueue.KQueueEventLoopGroup")
.asSubclass(EventLoopGroup.class)
.getDeclaredConstructor()
.newInstance();
builder.eventLoopGroup(elg);
}
} catch (ClassNotFoundException
| InstantiationException
| IllegalAccessException
| NoSuchMethodException
| InvocationTargetException e) {
throw new UnsupportedOperationException(
"Could not find suitable Netty native transport implementation for domain socket address.");
}
break;
}
default:
throw new IllegalArgumentException(
"Scheme is not supported: " + location.getUri().getScheme());
}

if (this.forceTls || LocationSchemes.GRPC_TLS.equals(location.getUri().getScheme())) {
builder.useTransportSecurity();

final boolean hasTrustedCerts = this.trustedCertificates != null;
final boolean hasKeyCertPair = this.clientCertificate != null && this.clientKey != null;
if (!this.verifyServer && (hasTrustedCerts || hasKeyCertPair)) {
throw new IllegalArgumentException(
"FlightClient has been configured to disable server verification, "
+ "but certificate options have been specified.");
}

final SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient();

if (!this.verifyServer) {
sslContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE);
} else if (this.trustedCertificates != null
|| this.clientCertificate != null
|| this.clientKey != null) {
if (this.trustedCertificates != null) {
sslContextBuilder.trustManager(this.trustedCertificates);
}
if (this.clientCertificate != null && this.clientKey != null) {
sslContextBuilder.keyManager(this.clientCertificate, this.clientKey);
}
}
try {
builder.sslContext(sslContextBuilder.build());
} catch (SSLException e) {
throw new RuntimeException(e);
}

if (this.overrideHostname != null) {
builder.overrideAuthority(this.overrideHostname);
}
} else {
builder.usePlaintext();
}

builder
.maxTraceEvents(MAX_CHANNEL_TRACE_EVENTS)
.maxInboundMessageSize(maxInboundMessageSize)
.maxInboundMetadataSize(maxInboundMessageSize);
return new FlightClient(allocator, builder.build(), middleware);
final NettyChannelBuilder channelBuilder = builder.build();
return new FlightClient(builder.allocator(), channelBuilder.build(), builder.middleware());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import org.apache.arrow.flight.auth.ServerAuthHandler;
Expand Down Expand Up @@ -151,6 +152,19 @@ public static FlightClient createFlightClient(
return new FlightClient(incomingAllocator, channel, Collections.emptyList());
}

/**
* Creates a Flight client.
*
* @param incomingAllocator Memory allocator
* @param channel provides a connection to a gRPC server.
*/
public static FlightClient createFlightClient(
BufferAllocator incomingAllocator,
ManagedChannel channel,
List<FlightClientMiddleware.Factory> middleware) {
return new FlightClient(incomingAllocator, channel, middleware);
}

/**
* Creates a Flight client.
*
Expand Down
Loading
Loading