Skip to content

Commit

Permalink
feat(callback): support configurable heartbeat interval (#360)
Browse files Browse the repository at this point in the history
Update protocol to support configurable heartbeat interval. In order to
keep backwards compatibility with older version of protocol, if no
heartbeat value is provided, we default to 5s.

Updated protocol to also support both `camelCase` and `snake_case`
extension values.
  • Loading branch information
dariuszkuc authored Nov 27, 2023
1 parent 42ae752 commit 9cc070c
Show file tree
Hide file tree
Showing 5 changed files with 216 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,23 @@
* identity
*/
public record SubscriptionCallback(
@NotNull String callback_url, @NotNull String subscription_id, @NotNull String verifier) {
@NotNull String callback_url,
@NotNull String subscription_id,
@NotNull String verifier,
int heartbeatIntervalMs) {

public static String SUBSCRIPTION_EXTENSION = "subscription";
public static String CALLBACK_URL = "callbackUrl";
public static String SUBSCRIPTION_ID = "subscriptionId";
public static String VERIFIER = "verifier";
public static String HEARTBEAT_INTERVAL_MS = "heartbeatIntervalMs";

// added for backwards compatibility with non GA callback protocol
// will be removed in next major release
@Deprecated private static int DEFAULT_HEARTBEAT_INTERVAL_MS = 5000;
@Deprecated private static String DEPRECATED_CALLBACK_URL = "callback_url";
@Deprecated private static String DEPRECATED_SUBSCRIPTION_ID = "subscription_id";
@Deprecated private static String DEPRECATED_HEARTBEAT_INTERVAL_MS = "heartbeat_interval_ms";

/**
* Parse subscription callback information from GraphQL request extension.
Expand All @@ -28,20 +44,43 @@ public record SubscriptionCallback(
@NotNull
public static Mono<SubscriptionCallback> parseSubscriptionCallbackExtension(
@NotNull Map<String, Object> extensions) {
var subscription_extension = extensions.get("subscription");
var subscription_extension = extensions.get(SUBSCRIPTION_EXTENSION);
if (subscription_extension instanceof Map subscription) {
var callback_url = subscription.get("callback_url");
var subscription_id = subscription.get("subscription_id");
var verifier = subscription.get("verifier");
var callback_url = subscription.get(CALLBACK_URL);
if (callback_url == null) {
callback_url = subscription.get(DEPRECATED_CALLBACK_URL);
}
var subscription_id = subscription.get(SUBSCRIPTION_ID);
if (subscription_id == null) {
subscription_id = subscription.get(DEPRECATED_SUBSCRIPTION_ID);
}
var verifier = subscription.get(VERIFIER);
var rawHeartbeatMs = subscription.get(HEARTBEAT_INTERVAL_MS);
if (rawHeartbeatMs == null) {
rawHeartbeatMs = subscription.get(DEPRECATED_HEARTBEAT_INTERVAL_MS);
}
var heartbeatMs = parseHeartbeats(rawHeartbeatMs);

if (callback_url != null && subscription_id != null && verifier != null) {
if (callback_url != null && subscription_id != null && verifier != null && heartbeatMs >= 0) {
return Mono.just(
new SubscriptionCallback(
(String) callback_url, (String) subscription_id, (String) verifier));
(String) callback_url, (String) subscription_id, (String) verifier, heartbeatMs));
} else {
return Mono.error(new InvalidCallbackExtensionException(subscription));
}
}
return Mono.error(new CallbackExtensionNotSpecifiedException());
}

private static int parseHeartbeats(Object heartbeatMs) {
if (heartbeatMs != null) {
try {
return (Integer) heartbeatMs;
} catch (ClassCastException e) {
// heartbeat_interval_ms is not an integer
return -1;
}
}
return DEFAULT_HEARTBEAT_INTERVAL_MS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,14 @@ protected Flux<SubscritionCallbackMessage> startSubscription(
@NotNull WebClient callbackClient,
@NotNull WebGraphQlRequest graphQlRequest,
@NotNull SubscriptionCallback callback) {
// infinite heartbeat flux
var checkMessage = new CallbackMessageCheck(callback.subscription_id(), callback.verifier());
Flux<SubscritionCallbackMessage> heartbeatFlux =
heartbeatFlux(callbackClient, checkMessage, callback);
// infinite heartbeat flux OR no heartbeat
Flux<SubscritionCallbackMessage> heartbeatFlux;
if (callback.heartbeatIntervalMs() > 0) {
var checkMessage = new CallbackMessageCheck(callback.subscription_id(), callback.verifier());
heartbeatFlux = heartbeatFlux(callbackClient, checkMessage, callback);
} else {
heartbeatFlux = Flux.empty();
}

// subscription data flux
Flux<SubscritionCallbackMessage> subscriptionFlux =
Expand Down Expand Up @@ -173,7 +177,7 @@ protected Flux<SubscritionCallbackMessage> startSubscription(
private Flux<SubscritionCallbackMessage> heartbeatFlux(
WebClient client, CallbackMessageCheck check, SubscriptionCallback callback) {
return Flux.just(check)
.delayElements(Duration.ofMillis(5000))
.delayElements(Duration.ofMillis(callback.heartbeatIntervalMs()))
.publishOn(scheduler)
.concatMap(
(heartbeat) ->
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
package com.apollographql.subscription;

import static com.apollographql.subscription.callback.SubscriptionCallback.CALLBACK_URL;
import static com.apollographql.subscription.callback.SubscriptionCallback.HEARTBEAT_INTERVAL_MS;
import static com.apollographql.subscription.callback.SubscriptionCallback.SUBSCRIPTION_EXTENSION;
import static com.apollographql.subscription.callback.SubscriptionCallback.SUBSCRIPTION_ID;
import static com.apollographql.subscription.callback.SubscriptionCallback.VERIFIER;

import java.util.HashMap;
import java.util.Map;

public class CallbackTestUtils {
public static Map<String, Object> createMockGraphQLRequest(
String subscriptionId, String callbackUrl) {
var subscriptionExtension = new HashMap<String, Object>();
subscriptionExtension.put("callback_url", callbackUrl);
subscriptionExtension.put("subscription_id", subscriptionId);
subscriptionExtension.put("verifier", "junit");
subscriptionExtension.put(CALLBACK_URL, callbackUrl);
subscriptionExtension.put(SUBSCRIPTION_ID, subscriptionId);
subscriptionExtension.put(VERIFIER, "junit");
subscriptionExtension.put(HEARTBEAT_INTERVAL_MS, 5000);
return Map.of(
"query",
"subscription { counter }",
"extensions",
Map.of("subscription", subscriptionExtension));
Map.of(SUBSCRIPTION_EXTENSION, subscriptionExtension));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class SubscriptionCallbackHandlerTest {

static class MockWebHandler implements WebGraphQlHandler {

private Flux subscriptionFlux;
private final Flux subscriptionFlux;

public MockWebHandler(Flux subscriptionFlux) {
this.subscriptionFlux = subscriptionFlux;
Expand All @@ -53,7 +53,7 @@ public WebSocketGraphQlInterceptor getWebSocketInterceptor() {
}

@Override
public Mono<WebGraphQlResponse> handleRequest(WebGraphQlRequest request) {
public @NotNull Mono<WebGraphQlResponse> handleRequest(@NotNull WebGraphQlRequest request) {
var executionResult = ExecutionResult.newExecutionResult().data(subscriptionFlux).build();
var executionResponse =
new DefaultExecutionGraphQlResponse(request.toExecutionInput(), executionResult);
Expand Down Expand Up @@ -86,7 +86,7 @@ public MockResponse dispatch(@NotNull RecordedRequest recordedRequest) {
var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
var verifier = "junit";
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier);
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier, 5000);

var graphQLRequest = stubWebGraphQLRequest(subscriptionId, callbackUrl);
var subscription = handler.handleSubscriptionUsingCallback(graphQLRequest, callback);
Expand Down Expand Up @@ -137,7 +137,7 @@ public MockResponse dispatch(@NotNull RecordedRequest recordedRequest) {
var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
var verifier = "junit";
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier);
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier, 5000);

var graphQLRequest = stubWebGraphQLRequest(subscriptionId, callbackUrl);
var subscription = handler.handleSubscriptionUsingCallback(graphQLRequest, callback);
Expand Down Expand Up @@ -176,7 +176,7 @@ public MockResponse dispatch(@NotNull RecordedRequest recordedRequest) {
var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
var verifier = "junit";
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier);
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier, 5000);
var client = WebClient.builder().baseUrl(callback.callback_url()).build();

var graphQLRequest = stubWebGraphQLRequest(subscriptionId, callbackUrl);
Expand Down Expand Up @@ -241,7 +241,7 @@ public MockResponse dispatch(@NotNull RecordedRequest recordedRequest) {
var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
var verifier = "junit";
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier);
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier, 5000);
var client = WebClient.builder().baseUrl(callback.callback_url()).build();

var graphQLRequest = stubWebGraphQLRequest(subscriptionId, callbackUrl);
Expand Down Expand Up @@ -285,7 +285,36 @@ public void subscription_success() {
var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
var verifier = "junit";
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier);
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier, 5000);
var client = WebClient.builder().baseUrl(callback.callback_url()).build();

var graphQLRequest = stubWebGraphQLRequest(subscriptionId, callbackUrl);
var subscription = handler.startSubscription(client, graphQLRequest, callback);
StepVerifier.create(subscription)
.expectNext(nextMessage(subscriptionId, verifier, 1))
.expectNext(nextMessage(subscriptionId, verifier, 2))
.expectNext(new CallbackMessageComplete(subscriptionId, verifier))
.verifyComplete();
} catch (IOException e) {
// failed to close the server
}
}

@Test
public void subscription_success_without_heartbeats() {
try (var server = new MockWebServer()) {
mockServerResponses(server, HttpStatus.OK, HttpStatus.OK, HttpStatus.ACCEPTED);

var data =
Flux.just(1, 2)
.delayElements(Duration.ofMillis(50))
.map((i) -> ExecutionResult.newExecutionResult().data(Map.of("counter", i)).build());
var handler = new SubscriptionCallbackHandler(new MockWebHandler(data));

var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
var verifier = "junit";
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier, 0);
var client = WebClient.builder().baseUrl(callback.callback_url()).build();

var graphQLRequest = stubWebGraphQLRequest(subscriptionId, callbackUrl);
Expand Down Expand Up @@ -315,7 +344,7 @@ public void subscription_exception() {
var subscriptionId = UUID.randomUUID().toString();
var callbackUrl = server.url("/callback/" + subscriptionId).toString();
var verifier = "junit";
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier);
var callback = new SubscriptionCallback(callbackUrl, subscriptionId, verifier, 5000);
var client = WebClient.builder().baseUrl(callback.callback_url()).build();

var graphQLRequest = stubWebGraphQLRequest(subscriptionId, callbackUrl);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package com.apollographql.subscription.callback;

import static com.apollographql.subscription.callback.SubscriptionCallback.CALLBACK_URL;
import static com.apollographql.subscription.callback.SubscriptionCallback.HEARTBEAT_INTERVAL_MS;
import static com.apollographql.subscription.callback.SubscriptionCallback.SUBSCRIPTION_EXTENSION;
import static com.apollographql.subscription.callback.SubscriptionCallback.SUBSCRIPTION_ID;
import static com.apollographql.subscription.callback.SubscriptionCallback.VERIFIER;

import com.apollographql.subscription.exception.CallbackExtensionNotSpecifiedException;
import com.apollographql.subscription.exception.InvalidCallbackExtensionException;
import java.util.Map;
import org.junit.jupiter.api.Test;
import reactor.test.StepVerifier;

public class SubscriptionCallbackTest {

@Test
public void callback_valid() {
var expected = new SubscriptionCallback("foo.com", "1234567890", "junit", 1000);
Map<String, Object> extension =
Map.of(
SUBSCRIPTION_EXTENSION,
Map.of(
CALLBACK_URL, expected.callback_url(),
SUBSCRIPTION_ID, expected.subscription_id(),
VERIFIER, expected.verifier(),
HEARTBEAT_INTERVAL_MS, expected.heartbeatIntervalMs()));
var callback = SubscriptionCallback.parseSubscriptionCallbackExtension(extension);
StepVerifier.create(callback).expectNext(expected).verifyComplete();
}

@Test
public void callback_missingExtension_returnsError() {
var callback = SubscriptionCallback.parseSubscriptionCallbackExtension(Map.of());
StepVerifier.create(callback)
.expectError(CallbackExtensionNotSpecifiedException.class)
.verify();
}

@Test
public void callback_missingCallbackUrl_returnsError() {
var callback =
SubscriptionCallback.parseSubscriptionCallbackExtension(
Map.of(
SUBSCRIPTION_EXTENSION,
Map.of(
SUBSCRIPTION_ID, "123",
VERIFIER, "junit",
HEARTBEAT_INTERVAL_MS, 1000)));
StepVerifier.create(callback).expectError(InvalidCallbackExtensionException.class).verify();
}

@Test
public void callback_missingHeartbeat_defaults5s() {
var expected = new SubscriptionCallback("foo.com", "1234567890", "junit", 5000);
Map<String, Object> extension =
Map.of(
SUBSCRIPTION_EXTENSION,
Map.of(
CALLBACK_URL, expected.callback_url(),
SUBSCRIPTION_ID, expected.subscription_id(),
VERIFIER, expected.verifier()));
var callback = SubscriptionCallback.parseSubscriptionCallbackExtension(extension);
StepVerifier.create(callback).expectNext(expected).verifyComplete();
}

@Test
public void callback_nonIntegerHeartbeat_returnsError() {
var callback =
SubscriptionCallback.parseSubscriptionCallbackExtension(
Map.of(
SUBSCRIPTION_EXTENSION,
Map.of(
CALLBACK_URL, "foo.com",
SUBSCRIPTION_ID, "123",
VERIFIER, "junit",
HEARTBEAT_INTERVAL_MS, "100")));
StepVerifier.create(callback).expectError(InvalidCallbackExtensionException.class).verify();
}

@Test
public void callback_negativeHeartbeat_returnsError() {
var callback =
SubscriptionCallback.parseSubscriptionCallbackExtension(
Map.of(
SUBSCRIPTION_EXTENSION,
Map.of(
CALLBACK_URL, "foo.com",
SUBSCRIPTION_ID, "123",
VERIFIER, "junit",
HEARTBEAT_INTERVAL_MS, -100)));
StepVerifier.create(callback).expectError(InvalidCallbackExtensionException.class).verify();
}

@Test
public void callback_usingSnakeCase_valid() {
var expected = new SubscriptionCallback("foo.com", "1234567890", "junit", 1000);
Map<String, Object> extension =
Map.of(
SUBSCRIPTION_EXTENSION,
Map.of(
"callback_url",
expected.callback_url(),
"subscription_id",
expected.subscription_id(),
VERIFIER,
expected.verifier(),
"heartbeat_interval_ms",
expected.heartbeatIntervalMs()));
var callback = SubscriptionCallback.parseSubscriptionCallbackExtension(extension);
StepVerifier.create(callback).expectNext(expected).verifyComplete();
}
}

0 comments on commit 9cc070c

Please sign in to comment.