Skip to content

Commit

Permalink
Merge pull request #31703 from alesj/gstork1
Browse files Browse the repository at this point in the history
  • Loading branch information
cescoffier authored Mar 10, 2023
2 parents b2e612f + bb3d84b commit c0fc995
Show file tree
Hide file tree
Showing 34 changed files with 856 additions and 151 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
import io.quarkus.grpc.runtime.config.GrpcClientBuildTimeConfig;
import io.quarkus.grpc.runtime.stork.GrpcStorkRecorder;
import io.quarkus.grpc.runtime.stork.StorkMeasuringGrpcInterceptor;
import io.quarkus.grpc.runtime.stork.VertxStorkMeasuringGrpcInterceptor;
import io.quarkus.grpc.runtime.supports.Channels;
import io.quarkus.grpc.runtime.supports.GrpcClientConfigProvider;
import io.quarkus.grpc.runtime.supports.IOThreadClientInterceptor;
Expand All @@ -97,6 +98,7 @@ void registerBeans(BuildProducer<AdditionalBeanBuildItem> beans) {
@BuildStep
void registerStorkInterceptor(BuildProducer<AdditionalBeanBuildItem> beans) {
beans.produce(new AdditionalBeanBuildItem(StorkMeasuringGrpcInterceptor.class));
beans.produce(new AdditionalBeanBuildItem(VertxStorkMeasuringGrpcInterceptor.class));
}

@BuildStep
Expand Down Expand Up @@ -407,6 +409,7 @@ SyntheticBeanBuildItem clientInterceptorStorage(GrpcClientRecorder recorder, Rec

// it's okay if this one is not used:
superfluousInterceptors.remove(StorkMeasuringGrpcInterceptor.class.getName());
superfluousInterceptors.remove(VertxStorkMeasuringGrpcInterceptor.class.getName());
if (!superfluousInterceptors.isEmpty()) {
LOGGER.warnf("At least one unused gRPC client interceptor found: %s. If there are meant to be used globally, " +
"annotate them with @GlobalInterceptor.", String.join(", ", superfluousInterceptors));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,12 @@ public void initializeGrpcServer(RuntimeValue<Vertx> vertxSupplier,
if (grpcContainer == null) {
throw new IllegalStateException("gRPC not initialized, GrpcContainer not found");
}
Vertx vertx = vertxSupplier.getValue();
if (hasNoServices(grpcContainer.getServices()) && LaunchMode.current() != LaunchMode.DEVELOPMENT) {
LOGGER.error("Unable to find beans exposing the `BindableService` interface - not starting the gRPC server");
return; // OK?
}

Vertx vertx = vertxSupplier.getValue();
GrpcServerConfiguration configuration = cfg.server;
GrpcBuilderProvider<?> provider = GrpcBuilderProvider.findServerBuilderProvider(configuration);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ public class GrpcClientConfiguration {
@ConfigItem
public InProcess inProcess;

/**
* Configure Stork usage with new Vert.x gRPC, if enabled.
*/
@ConfigItem
public StorkConfig stork;

/**
* The gRPC service port.
*/
Expand Down Expand Up @@ -168,7 +174,7 @@ public class GrpcClientConfiguration {

/**
* Use a custom load balancing policy.
* Accepted values are: {@code pick_value}, {@code round_robin}, {@code grpclb}.
* Accepted values are: {@code pick_first}, {@code round_robin}, {@code grpclb}.
* This value is ignored if name-resolver is set to 'stork'.
*/
@ConfigItem(defaultValue = "pick_first")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package io.quarkus.grpc.runtime.config;

import io.quarkus.runtime.annotations.ConfigGroup;
import io.quarkus.runtime.annotations.ConfigItem;

/**
* Stork config for new Vert.x gRPC
*/
@ConfigGroup
public class StorkConfig {
/**
* Number of threads on a delayed gRPC ClientCall
*/
@ConfigItem(defaultValue = "10")
public int threads;

/**
* Deadline in milliseconds of delayed gRPC call
*/
@ConfigItem(defaultValue = "5000")
public long deadline;

/**
* Number of retries on a gRPC ClientCall
*/
@ConfigItem(defaultValue = "3")
public int retries;

/**
* Initial delay in seconds on refresh check
*/
@ConfigItem(defaultValue = "60")
public long delay;

/**
* Refresh period in seconds
*/
@ConfigItem(defaultValue = "120")
public long period;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package io.quarkus.grpc.runtime.stork;

import io.grpc.ClientCall;
import io.grpc.ForwardingClientCall;
import io.smallrye.stork.api.ServiceInstance;

abstract class AbstractStorkMeasuringCall<ReqT, RespT> extends ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>
implements StorkMeasuringCollector {
final boolean recordTime;

protected AbstractStorkMeasuringCall(ClientCall<ReqT, RespT> delegate, boolean recordTime) {
super(delegate);
this.recordTime = recordTime;
}

protected abstract ServiceInstance serviceInstance();

public void recordReply() {
if (serviceInstance() != null && recordTime) {
serviceInstance().recordReply();
}
}

public void recordEnd(Throwable error) {
if (serviceInstance() != null) {
serviceInstance().recordEnd(error);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import static io.grpc.ConnectivityState.IDLE;
import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
import static io.quarkus.grpc.runtime.stork.StorkMeasuringGrpcInterceptor.STORK_MEASURE_TIME;
import static io.quarkus.grpc.runtime.stork.StorkMeasuringGrpcInterceptor.STORK_SERVICE_INSTANCE;
import static io.quarkus.grpc.runtime.stork.StorkMeasuringCollector.STORK_MEASURE_TIME;
import static io.quarkus.grpc.runtime.stork.StorkMeasuringCollector.STORK_SERVICE_INSTANCE;

import java.util.Collections;
import java.util.Comparator;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ private void informListener(List<ServiceInstance> instances) {
socketAddresses.add(new InetSocketAddress(inetAddress, instance.getPort()));
}
} catch (UnknownHostException e) {
log.errorf(e, "Ignoring wrong host: '%s' for service name '%s'", instance.getHost(),
log.warnf(e, "Ignoring wrong host: '%s' for service name '%s'", instance.getHost(),
serviceName);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package io.quarkus.grpc.runtime.stork;

import static io.quarkus.grpc.runtime.stork.StorkMeasuringCollector.STORK_MEASURE_TIME;
import static io.quarkus.grpc.runtime.stork.StorkMeasuringCollector.STORK_SERVICE_INSTANCE;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import javax.annotation.Nullable;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.Deadline;
import io.grpc.MethodDescriptor;
import io.grpc.internal.DelayedClientCall;
import io.quarkus.grpc.runtime.config.StorkConfig;
import io.smallrye.mutiny.Uni;
import io.smallrye.stork.Stork;
import io.smallrye.stork.api.Service;
import io.smallrye.stork.api.ServiceInstance;
import io.vertx.core.net.SocketAddress;
import io.vertx.grpc.client.GrpcClient;
import io.vertx.grpc.client.GrpcClientChannel;

public class StorkGrpcChannel extends Channel implements AutoCloseable {
private static final Logger log = LoggerFactory.getLogger(StorkGrpcChannel.class);

private final Map<Long, ServiceInstance> services = new ConcurrentHashMap<>();
private final Map<Long, Channel> channels = new ConcurrentHashMap<>();
private final ScheduledExecutorService scheduler;

private final GrpcClient client;
private final String serviceName;
private final StorkConfig stork;
private final Executor executor;

private static class Context {
Service service;
boolean measureTime;
ServiceInstance instance;
InetSocketAddress address;
Channel channel;
AtomicReference<ServiceInstance> ref;
}

public StorkGrpcChannel(GrpcClient client, String serviceName, StorkConfig stork, Executor executor) {
this.client = client;
this.serviceName = serviceName;
this.stork = stork;
this.executor = executor;
this.scheduler = new ScheduledThreadPoolExecutor(stork.threads);
this.scheduler.scheduleAtFixedRate(this::refresh, stork.delay, stork.period, TimeUnit.SECONDS);
}

@Override
public <RequestT, ResponseT> ClientCall<RequestT, ResponseT> newCall(MethodDescriptor<RequestT, ResponseT> methodDescriptor,
CallOptions callOptions) {
Service service = Stork.getInstance().getService(serviceName);
if (service == null) {
throw new IllegalStateException("No service definition for serviceName " + serviceName + " found.");
}

Context context = new Context();
context.service = service;
// handle this calls here
Boolean measureTime = STORK_MEASURE_TIME.get();
context.measureTime = measureTime != null && measureTime;
context.ref = STORK_SERVICE_INSTANCE.get();

DelayedClientCall<RequestT, ResponseT> delayed = new StorkDelayedClientCall<>(executor, scheduler,
Deadline.after(stork.deadline, TimeUnit.MILLISECONDS));

asyncCall(methodDescriptor, callOptions, context)
.onFailure()
.retry()
.atMost(stork.retries)
.subscribe()
.asCompletionStage()
.thenApply(delayed::setCall)
.thenAccept(Runnable::run)
.exceptionally(t -> {
delayed.cancel("Failed to create new Stork ClientCall", t);
return null;
});

return delayed;
}

private <RequestT, ResponseT> Uni<ClientCall<RequestT, ResponseT>> asyncCall(
MethodDescriptor<RequestT, ResponseT> methodDescriptor, CallOptions callOptions, Context context) {
Uni<Context> entry = pickServiceInstanceWithChannel(context);
return entry.map(c -> {
ServiceInstance instance = c.instance;
long serviceId = instance.getId();
Channel channel = c.channel;
try {
services.put(serviceId, instance);
channels.put(serviceId, channel);
return channel.newCall(methodDescriptor, callOptions);
} catch (Exception ex) {
// remove, no good
services.remove(serviceId);
channels.remove(serviceId);
throw new IllegalStateException(ex);
}
});
}

@Override
public String authority() {
return null;
}

@Override
public void close() {
scheduler.shutdown();
}

@Override
public String toString() {
return super.toString() + String.format(" [%s]", serviceName);
}

private void refresh() {
// any better way to know which are OK / bad?
services.clear();
channels.clear();
}

private Uni<Context> pickServiceInstanceWithChannel(Context context) {
Uni<ServiceInstance> uni = pickServerInstance(context.service, context.measureTime);
return uni
.map(si -> {
context.instance = si;
if (si.gatherStatistics() && context.ref != null) {
context.ref.set(si);
}
return context;
})
.invoke(this::checkSocketAddress)
.invoke(c -> {
ServiceInstance instance = context.instance;
InetSocketAddress isa = context.address;
context.channel = channels.computeIfAbsent(instance.getId(), id -> {
SocketAddress address = SocketAddress.inetSocketAddress(isa.getPort(), isa.getHostName());
return new GrpcClientChannel(client, address);
});
});
}

private Uni<ServiceInstance> pickServerInstance(Service service, boolean measureTime) {
return Uni.createFrom()
.deferred(() -> {
if (services.isEmpty()) {
return service.getInstances()
.invoke(l -> l.forEach(s -> services.put(s.getId(), s)));
} else {
List<ServiceInstance> list = new ArrayList<>(services.values());
return Uni.createFrom().item(list);
}
})
.invoke(list -> {
// list should not be empty + sort by id
list.sort(Comparator.comparing(ServiceInstance::getId));
})
.map(list -> service.selectInstanceAndRecordStart(list, measureTime));
}

private void checkSocketAddress(Context context) {
ServiceInstance instance = context.instance;
Set<InetSocketAddress> socketAddresses = new HashSet<>();
try {
for (InetAddress inetAddress : InetAddress.getAllByName(instance.getHost())) {
socketAddresses.add(new InetSocketAddress(inetAddress, instance.getPort()));
}
} catch (UnknownHostException e) {
log.warn("Ignoring wrong host: '{}' for service name '{}'", instance.getHost(), serviceName, e);
}

if (!socketAddresses.isEmpty()) {
context.address = socketAddresses.iterator().next(); // pick first
} else {
long serviceId = instance.getId();
services.remove(serviceId);
channels.remove(serviceId);
throw new IllegalStateException("Failed to determine working socket addresses for service-name: " + serviceName);
}
}

private static class StorkDelayedClientCall<RequestT, ResponseT> extends DelayedClientCall<RequestT, ResponseT> {
public StorkDelayedClientCall(Executor callExecutor, ScheduledExecutorService scheduler, @Nullable Deadline deadline) {
super(callExecutor, scheduler, deadline);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package io.quarkus.grpc.runtime.stork;

import io.grpc.ClientCall;
import io.grpc.ForwardingClientCall;
import io.smallrye.stork.api.ServiceInstance;

abstract class StorkMeasuringCall<ReqT, RespT> extends ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>
implements StorkMeasuringCollector {
final boolean recordTime;

protected StorkMeasuringCall(ClientCall<ReqT, RespT> delegate, boolean recordTime) {
super(delegate);
this.recordTime = recordTime;
}

protected abstract ServiceInstance serviceInstance();

public void recordReply() {
if (serviceInstance() != null && recordTime) {
serviceInstance().recordReply();
}
}

public void recordEnd(Throwable error) {
if (serviceInstance() != null) {
serviceInstance().recordEnd(error);
}
}
}
Loading

0 comments on commit c0fc995

Please sign in to comment.