diff --git a/docs/assets/gateway-routing-flow.svg b/docs/assets/gateway-routing-flow.svg
new file mode 100644
index 000000000..07db084f8
--- /dev/null
+++ b/docs/assets/gateway-routing-flow.svg
@@ -0,0 +1,1665 @@
+
+
+
+
+
+
+
+
+
+
+
+
+ Request
+ Client
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ Send request to random member of routing group
+
+ Send request to least loaded member of Routing Group
+ Is Adaptive Load Balancing enabled?
+ Select "adhoc" routing group
+ Select Routing Group from highest priority matched rule
+ Does request match a Routing Rule?
+ Is request related to a previous request?
+ Send to same backend as previous
+
+
+
+
+
diff --git a/docs/design.md b/docs/design.md
index 91befeaa8..5cf2e492d 100644
--- a/docs/design.md
+++ b/docs/design.md
@@ -18,7 +18,8 @@
Gateway API
Resource groups API
Routing rules
-
+ Routing logic
+
diff --git a/docs/gateway-api.md b/docs/gateway-api.md
index 81388107b..558a8da1e 100644
--- a/docs/gateway-api.md
+++ b/docs/gateway-api.md
@@ -18,6 +18,7 @@
Gateway API
Resource groups API
Routing rules
+ Routing logic
diff --git a/docs/installation.md b/docs/installation.md
index 0fee2b91e..137ad453e 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -18,6 +18,7 @@
Gateway API
Resource groups API
Routing rules
+ Routing logic
diff --git a/docs/operation.md b/docs/operation.md
index e0c167f9e..c107899b0 100644
--- a/docs/operation.md
+++ b/docs/operation.md
@@ -18,6 +18,7 @@
Gateway API
Resource groups API
Routing rules
+ Routing logic
diff --git a/docs/quickstart.md b/docs/quickstart.md
index 73a97240c..d5730e443 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -18,6 +18,7 @@
Gateway API
Resource groups API
Routing rules
+ Routing logic
diff --git a/docs/references.md b/docs/references.md
index 884914642..d8df9238d 100644
--- a/docs/references.md
+++ b/docs/references.md
@@ -18,6 +18,7 @@
Gateway API
Resource groups API
Routing rules
+ Routing logic
diff --git a/docs/release-notes.md b/docs/release-notes.md
index b9c04da0d..78cc3a6a5 100644
--- a/docs/release-notes.md
+++ b/docs/release-notes.md
@@ -18,6 +18,7 @@
Gateway API
Resource groups API
Routing rules
+ Routing logic
diff --git a/docs/resource-groups-api.md b/docs/resource-groups-api.md
index 330e93616..9e4680f8b 100644
--- a/docs/resource-groups-api.md
+++ b/docs/resource-groups-api.md
@@ -18,6 +18,7 @@
Gateway API
Resource groups API
Routing rules
+ Routing logic
diff --git a/docs/routing-logic.md b/docs/routing-logic.md
new file mode 100644
index 000000000..a3c3ed18b
--- /dev/null
+++ b/docs/routing-logic.md
@@ -0,0 +1,111 @@
+**Trino Gateway documentation**
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+# Routing Logic
+
+## Overview
+
+Trino Gateway checks incoming requests to see if they're related to previous
+ones it handled. If they are, then Trino Gateway sends them to the same backend
+that dealt with the earlier requests.
+
+If it is a new request, the Trino Gateway refers to [Routing rules](routing-rules.md)
+to decide which group of backends, called a 'Routing Group,' should handle it.
+It then picks a backend from that Routing Group to handle the request using
+either an adaptive or round-robin strategy.
+
+
+
+## Sticky routing
+
+A request related to an ongoing process, or to state maintained on a single
+backend cluster, must be routed to that backend for proper handling. Two
+mechanisms for identifying related requests are currently implemented. By default,
+only routing based on query identifier is enabled.
+
+### Routing based on query identifier (default)
+
+When a query is initiated through the Trino Gateway, the query id will be
+extracted from the response and mapped to the backend that provided the
+response. Any subsequent request containing that query id will be forwarded
+to that backend. For example, to retrieve query results, the trino client
+polls a URI of the form
+`v1/statement/executing/queryid/nonce/counter`. The Trino Gateway will extract
+the queryid from this URI.
+
+### Routing based on cookies
+
+OAuth2 authentication requires that the same backend is used for each step of
+the handshake. When `gatewayCookieConfiguration.enabled` is set to true, a cookie
+will be added to requests made to paths beginning with `/oauth2` unless they already have
+a cookie present, which is used to route further `/oauth2/*` requests to the correct backend.
+Cookies are not added to requests to `v1/*` and other Trino endpoints.
+
+Trino Gateway signs its cookies to ensure that they are not tampered with. You
+must set a `cookieSigningSecret` string in your configuration
+```yaml
+gatewayCookieConfiguration:
+ enabled: true
+ cookieSigningSecret: "ahighentropystring"
+```
+when making use of this feature. If you load balance request across multiple Trino Gateway
+instances, ensure each instance has the same `cookieSigningSecret`.
+
+The Trino Gateway handles standard Trino OAuth2 handshakes with no additional
+configuration. If you are using a customized or commercial Trino distribution, then
+the paths used to define the OAuth handshake may be modified.
+
+`routingPaths`: If the request URI starts with a path in this list, then
+* If no cookie is present, add a routing cookie
+* If a cookie is present, route the request to the backend defined by that cookie
+
+`deletePaths`: If the request URI starts with a path in this list,
+return a response that instructs the client to delete the cookie.
+
+Additionally, the `lifetime` property sets the duration for which a cookie remains in
+effect after creation. Ensure that it is greater than
+the time required to complete the handshake. Default `lifetime` is 10 minutes.
+
+These properties are defined under the `oauth2GatewayCookieConfiguration` node:
+
+```yaml
+oauth2GatewayCookieConfiguration:
+ routingPaths:
+ - "/oauth2"
+ - "/custom/oauth2/callback"
+ - "/alternative/oauth2/initiate"
+ deletePaths:
+ - "/custom/logout"
+ lifetime: "5m"
+```
diff --git a/docs/routing-rules.md b/docs/routing-rules.md
index 006ec19d7..e7a3733ef 100644
--- a/docs/routing-rules.md
+++ b/docs/routing-rules.md
@@ -18,6 +18,7 @@
Gateway API
Resource groups API
Routing rules
+ Routing logic
diff --git a/docs/security.md b/docs/security.md
index 4746d712d..147543e7a 100644
--- a/docs/security.md
+++ b/docs/security.md
@@ -18,6 +18,7 @@
Gateway API
Resource groups API
Routing rules
+ Routing logic
diff --git a/gateway-ha/pom.xml b/gateway-ha/pom.xml
index 5f58ba9ef..4e5a0aa36 100644
--- a/gateway-ha/pom.xml
+++ b/gateway-ha/pom.xml
@@ -126,6 +126,12 @@
stats
+
+ io.airlift
+ units
+ 1.10
+
+
io.dropwizard
dropwizard-assets
diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfiguration.java
new file mode 100644
index 000000000..e19133596
--- /dev/null
+++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfiguration.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.gateway.ha.config;
+
+import javax.crypto.SecretKey;
+import javax.crypto.spec.SecretKeySpec;
+
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+public class GatewayCookieConfiguration
+{
+ private SecretKey cookieSigningKey;
+ private boolean enabled;
+
+ public boolean isEnabled()
+ {
+ return enabled;
+ }
+
+ public void setEnabled(boolean enabled)
+ {
+ this.enabled = enabled;
+ }
+
+ public SecretKey getCookieSigningKey()
+ {
+ return cookieSigningKey;
+ }
+
+ public void setCookieSigningSecret(String cookieSigningSecret)
+ {
+ cookieSigningKey = new SecretKeySpec(cookieSigningSecret.getBytes(UTF_8), "HmacSHA256");
+ }
+}
diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfigurationPropertiesProvider.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfigurationPropertiesProvider.java
new file mode 100644
index 000000000..e5f0dedee
--- /dev/null
+++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfigurationPropertiesProvider.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.gateway.ha.config;
+
+import javax.crypto.SecretKey;
+
+public class GatewayCookieConfigurationPropertiesProvider
+{
+ private static final GatewayCookieConfigurationPropertiesProvider instance = new GatewayCookieConfigurationPropertiesProvider();
+ private GatewayCookieConfiguration gatewayCookieConfiguration;
+
+ private GatewayCookieConfigurationPropertiesProvider()
+ {}
+
+ public void initialize(GatewayCookieConfiguration gatewayCookieConfiguration)
+ {
+ if (gatewayCookieConfiguration.isEnabled() && gatewayCookieConfiguration.getCookieSigningKey() == null) {
+ throw new IllegalArgumentException("gatewayCookieConfiguration.cookieSigningSecret must be provided when cookies are enabled");
+ }
+ this.gatewayCookieConfiguration = gatewayCookieConfiguration;
+ }
+
+ public static GatewayCookieConfigurationPropertiesProvider getInstance()
+ {
+ return instance;
+ }
+
+ public boolean isEnabled()
+ {
+ ensureInitialized();
+ return gatewayCookieConfiguration.isEnabled();
+ }
+
+ public SecretKey getCookieSigningKey()
+ {
+ ensureInitialized();
+ return gatewayCookieConfiguration.getCookieSigningKey();
+ }
+
+ private void ensureInitialized()
+ {
+ if (gatewayCookieConfiguration == null) {
+ throw new IllegalStateException("getInstance.initialize(GatewayCookieConfiguration) must be called before use");
+ }
+ }
+}
diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java
index a0ec3d6b4..83bba6172 100644
--- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java
+++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java
@@ -35,6 +35,8 @@ public class HaGatewayConfiguration
private BackendStateConfiguration backendState;
private ClusterStatsConfiguration clusterStatsConfiguration;
private List extraWhitelistPaths = new ArrayList<>();
+ private OAuth2GatewayCookieConfiguration oauth2GatewayCookieConfiguration = new OAuth2GatewayCookieConfiguration();
+ private GatewayCookieConfiguration gatewayCookieConfiguration = new GatewayCookieConfiguration();
// List of Modules with FQCN (Fully Qualified Class Name)
private List modules;
@@ -164,6 +166,26 @@ public void setExtraWhitelistPaths(List extraWhitelistPaths)
this.extraWhitelistPaths = extraWhitelistPaths;
}
+ public OAuth2GatewayCookieConfiguration getOauth2GatewayCookieConfiguration()
+ {
+ return oauth2GatewayCookieConfiguration;
+ }
+
+ public void setOauth2GatewayCookieConfiguration(OAuth2GatewayCookieConfiguration oauth2GatewayCookieConfiguration)
+ {
+ this.oauth2GatewayCookieConfiguration = oauth2GatewayCookieConfiguration;
+ }
+
+ public GatewayCookieConfiguration getGatewayCookieConfiguration()
+ {
+ return gatewayCookieConfiguration;
+ }
+
+ public void setGatewayCookieConfiguration(GatewayCookieConfiguration gatewayCookieConfiguration)
+ {
+ this.gatewayCookieConfiguration = gatewayCookieConfiguration;
+ }
+
public List getModules()
{
return this.modules;
diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfiguration.java
new file mode 100644
index 000000000..d5224a16e
--- /dev/null
+++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfiguration.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.gateway.ha.config;
+
+import com.google.common.collect.ImmutableList;
+import io.airlift.units.Duration;
+
+import java.util.List;
+
+public class OAuth2GatewayCookieConfiguration
+{
+ // Configuration initialization using dropwizard requires
+ // instance method setters. Values are global, and can be accessed using static getters
+ private List routingPaths = ImmutableList.of("/oauth2");
+ private List deletePaths = ImmutableList.of("/logout", "/oauth2/logout");
+ private Duration lifetime = Duration.valueOf("10m");
+
+ public List getDeletePaths()
+ {
+ return deletePaths;
+ }
+
+ public void setDeletePaths(List deletePaths)
+ {
+ this.deletePaths = deletePaths;
+ }
+
+ public List getRoutingPaths()
+ {
+ return routingPaths;
+ }
+
+ public void setRoutingPaths(List routingPaths)
+ {
+ this.routingPaths = routingPaths;
+ }
+
+ public Duration getLifetime()
+ {
+ return lifetime;
+ }
+
+ public void setLifetime(String lifetime)
+ {
+ this.lifetime = Duration.valueOf(lifetime);
+ }
+}
diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfigurationPropertiesProvider.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfigurationPropertiesProvider.java
new file mode 100644
index 000000000..dfbd14560
--- /dev/null
+++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfigurationPropertiesProvider.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.gateway.ha.config;
+
+import io.airlift.units.Duration;
+
+import java.util.List;
+
+public class OAuth2GatewayCookieConfigurationPropertiesProvider
+{
+ private static final OAuth2GatewayCookieConfigurationPropertiesProvider instance = new OAuth2GatewayCookieConfigurationPropertiesProvider();
+
+ private OAuth2GatewayCookieConfiguration oAuth2GatewayCookieConfiguration;
+
+ private OAuth2GatewayCookieConfigurationPropertiesProvider()
+ {}
+
+ public static OAuth2GatewayCookieConfigurationPropertiesProvider getInstance()
+ {
+ return instance;
+ }
+
+ public void initialize(OAuth2GatewayCookieConfiguration oAuth2GatewayCookieConfiguration)
+ {
+ this.oAuth2GatewayCookieConfiguration = oAuth2GatewayCookieConfiguration;
+ }
+
+ public List getDeletePaths()
+ {
+ ensureInitialized();
+ return oAuth2GatewayCookieConfiguration.getDeletePaths();
+ }
+
+ public List getRoutingPaths()
+ {
+ ensureInitialized();
+ return oAuth2GatewayCookieConfiguration.getRoutingPaths();
+ }
+
+ public Duration getLifetime()
+ {
+ ensureInitialized();
+ return oAuth2GatewayCookieConfiguration.getLifetime();
+ }
+
+ private void ensureInitialized()
+ {
+ if (oAuth2GatewayCookieConfiguration == null) {
+ throw new IllegalStateException("getInstance.initialize(OAuth2GatewayCookieConfiguration) must be called before use");
+ }
+ }
+}
diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java
index 792496c30..4a401a35e 100644
--- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java
+++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java
@@ -15,13 +15,18 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableList;
import com.google.common.io.CharStreams;
import io.airlift.log.Logger;
+import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider;
+import io.trino.gateway.ha.router.GatewayCookie;
+import io.trino.gateway.ha.router.OAuth2GatewayCookie;
import io.trino.gateway.ha.router.QueryHistoryManager;
import io.trino.gateway.ha.router.RoutingGroupSelector;
import io.trino.gateway.ha.router.RoutingManager;
import io.trino.gateway.proxyserver.ProxyHandler;
import io.trino.gateway.proxyserver.wrapper.MultiReadHttpServletRequest;
+import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.ws.rs.HttpMethod;
@@ -32,9 +37,11 @@
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
+import java.util.Arrays;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
+import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -84,6 +91,7 @@ public class QueryIdCachingProxyHandler
private final ProxyHandlerStats proxyHandlerStats;
private final List extraWhitelistPaths;
private final String applicationEndpoint;
+ private final boolean cookiesEnabled;
public QueryIdCachingProxyHandler(
QueryHistoryManager queryHistoryManager,
@@ -99,6 +107,7 @@ public QueryIdCachingProxyHandler(
this.queryHistoryManager = queryHistoryManager;
this.extraWhitelistPaths = extraWhitelistPaths;
this.applicationEndpoint = "http://localhost:" + serverApplicationPort;
+ cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled();
}
protected static String extractQueryIdIfPresent(String path, String queryParams)
@@ -269,9 +278,23 @@ private boolean isPathWhiteListed(String path)
|| extraWhitelistPaths.stream().anyMatch(s -> path.startsWith(s));
}
- public boolean handleAuthRequest(HttpServletRequest request)
+ @Override
+ public List generateDeleteCookieList(HttpServletRequest clientRequest)
{
- return true;
+ if (!cookiesEnabled || clientRequest.getCookies() == null) {
+ return ImmutableList.of();
+ }
+
+ return Arrays.stream(clientRequest.getCookies())
+ .filter(c -> c.getName().startsWith(GatewayCookie.PREFIX))
+ .map(GatewayCookie::fromCookie)
+ .filter(c -> !c.isValid() || c.matchesDeletePath(clientRequest.getRequestURI()))
+ .map(GatewayCookie::toCookie)
+ .peek(c -> {
+ c.setValue("delete");
+ c.setMaxAge(0);
+ })
+ .toList();
}
@Override
@@ -281,10 +304,10 @@ public String rewriteTarget(HttpServletRequest request)
return buildUriWithNewBackend(applicationEndpoint, request);
}
- if (isUseStickyRouting(request)) {
- String backend = getPreviousBackend(request);
- logRewrite(backend, request);
- return buildUriWithNewBackend(backend, request);
+ Optional previousBackend = getPreviousBackend(request);
+ if (previousBackend.isPresent()) {
+ logRewrite(previousBackend.orElseThrow(), request);
+ return previousBackend.map(b -> buildUriWithNewBackend(b, request)).orElseThrow();
}
String backend = getBackendFromRoutingGroup(request);
@@ -322,21 +345,27 @@ private String getBackendFromRoutingGroup(HttpServletRequest request)
return routingManager.provideAdhocBackend(user);
}
- private String getPreviousBackend(HttpServletRequest request)
+ private Optional getPreviousBackend(HttpServletRequest request)
{
String queryId = extractQueryIdIfPresent(request);
-
- // Find query id and get url from cache
if (!isNullOrEmpty(queryId)) {
- return routingManager.findBackendForQueryId(queryId);
+ return Optional.of(routingManager.findBackendForQueryId(queryId));
+ }
+ if (cookiesEnabled && request.getCookies() != null) {
+ List cookies = Arrays.stream(request.getCookies())
+ .filter(c -> c.getName().startsWith(GatewayCookie.PREFIX))
+ .map(GatewayCookie::fromCookie)
+ .filter(GatewayCookie::isValid)
+ .filter(c -> !isNullOrEmpty(c.getBackend()))
+ .filter(c -> c.matchesRoutingPath(request.getRequestURI()))
+ .sorted()
+ .toList();
+ if (!cookies.isEmpty()) {
+ return Optional.of(cookies.getFirst().getBackend());
+ }
}
- log.error("No backend found for queryId %s", queryId);
- return getBackendFromRoutingGroup(request);
- }
- private boolean isUseStickyRouting(HttpServletRequest request)
- {
- return !isNullOrEmpty(extractQueryIdIfPresent(request));
+ return Optional.empty();
}
@Override
@@ -349,50 +378,14 @@ protected void postConnectionHook(
Callback callback)
{
try {
- String requestPath = request.getRequestURI();
- if (requestPath.startsWith(V1_STATEMENT_PATH)
- && request.getMethod().equals(HttpMethod.POST)) {
- String output;
- boolean isGZipEncoding = isGZipEncoding(response);
- if (isGZipEncoding) {
- output = plainTextFromGz(buffer);
- }
- else {
- output = new String(buffer);
- }
- log.debug("For Request [%s] got Response output [%s]", request.getRequestURI(), output);
-
- QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request);
- log.debug("Extracting Proxy destination : [%s] for request : [%s]",
- queryDetail.getBackendUrl(), request.getRequestURI());
-
- if (response.getStatus() == HttpStatus.OK_200) {
- HashMap results = OBJECT_MAPPER.readValue(output, HashMap.class);
- queryDetail.setQueryId(results.get("id"));
-
- if (!isNullOrEmpty(queryDetail.getQueryId())) {
- routingManager.setBackendForQueryId(
- queryDetail.getQueryId(), queryDetail.getBackendUrl());
- log.debug(
- "QueryId [%s] mapped with proxy [%s]",
- queryDetail.getQueryId(),
- queryDetail.getBackendUrl());
- }
- else {
- log.debug("QueryId [%s] could not be cached", queryDetail.getQueryId());
- }
- }
- else {
- log.error(
- "Non OK HTTP Status code with response [%s] , Status code [%s]",
- output,
- response.getStatus());
- }
- // Saving history at gateway.
- queryHistoryManager.submitQueryDetail(queryDetail);
+ if (request.getRequestURI().startsWith(V1_STATEMENT_PATH) && request.getMethod().equals(HttpMethod.POST)) {
+ recordBackendForQueryId(request, response, buffer);
}
- else {
- log.debug("SKIPPING For %s", requestPath);
+ else if (cookiesEnabled && request.getRequestURI().startsWith(OAuth2GatewayCookie.OAUTH2_PATH)
+ && !(request.getCookies() != null
+ && Arrays.stream(request.getCookies()).anyMatch(c -> c.getName().equals(OAuth2GatewayCookie.NAME)))) {
+ GatewayCookie oauth2Cookie = new OAuth2GatewayCookie(request.getHeader(PROXY_TARGET_HEADER));
+ response.addCookie(oauth2Cookie.toCookie());
}
}
catch (Exception e) {
@@ -401,6 +394,42 @@ protected void postConnectionHook(
super.postConnectionHook(request, response, buffer, offset, length, callback);
}
+ void recordBackendForQueryId(HttpServletRequest request, HttpServletResponse response, byte[] buffer)
+ throws IOException
+ {
+ String output;
+ boolean isGZipEncoding = isGZipEncoding(response);
+ if (isGZipEncoding) {
+ output = plainTextFromGz(buffer);
+ }
+ else {
+ output = new String(buffer);
+ }
+ log.debug("For Request [%s] got Response output [%s]", request.getRequestURI(), output);
+
+ QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request);
+
+ if (queryDetail.getBackendUrl() == null) {
+ log.error("Server response to request %s does not contain proxytarget header", request.getRequestURI());
+ }
+ log.debug("Extracting Proxy destination : [%s] for request : [%s]", queryDetail.getBackendUrl(), request.getRequestURI());
+
+ if (response.getStatus() == HttpStatus.OK_200) {
+ HashMap results = OBJECT_MAPPER.readValue(output, HashMap.class);
+ queryDetail.setQueryId(results.get("id"));
+
+ if (!isNullOrEmpty(queryDetail.getQueryId())) {
+ routingManager.setBackendForQueryId(queryDetail.getQueryId(), queryDetail.getBackendUrl());
+ log.debug("QueryId [%s] mapped with proxy [%s]", queryDetail.getQueryId(), queryDetail.getBackendUrl());
+ }
+ }
+ else {
+ log.error("Non OK HTTP Status code with response [%s] , Status code [%s]", output, response.getStatus());
+ }
+ // Save history in Trino Gateway.
+ queryHistoryManager.submitQueryDetail(queryDetail);
+ }
+
private QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(HttpServletRequest request)
throws IOException
{
diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java
index df8a5239d..8ee294d5a 100644
--- a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java
+++ b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java
@@ -27,7 +27,9 @@
import io.dropwizard.jetty.HttpConnectorFactory;
import io.trino.gateway.ha.config.AuthenticationConfiguration;
import io.trino.gateway.ha.config.AuthorizationConfiguration;
+import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider;
import io.trino.gateway.ha.config.HaGatewayConfiguration;
+import io.trino.gateway.ha.config.OAuth2GatewayCookieConfigurationPropertiesProvider;
import io.trino.gateway.ha.config.RequestRouterConfiguration;
import io.trino.gateway.ha.config.RoutingRulesConfiguration;
import io.trino.gateway.ha.config.UserConfiguration;
@@ -86,6 +88,12 @@ public HaGatewayProviderModule(HaGatewayConfiguration configuration, Environment
authenticationFilter = getAuthFilter(configuration);
backendStateConnectionManager = new BackendStateManager();
extraWhitelistPaths = configuration.getExtraWhitelistPaths();
+
+ GatewayCookieConfigurationPropertiesProvider gatewayCookieConfigurationPropertiesProvider = GatewayCookieConfigurationPropertiesProvider.getInstance();
+ gatewayCookieConfigurationPropertiesProvider.initialize(configuration.getGatewayCookieConfiguration());
+
+ OAuth2GatewayCookieConfigurationPropertiesProvider oAuth2GatewayCookieConfigurationPropertiesProvider = OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance();
+ oAuth2GatewayCookieConfigurationPropertiesProvider.initialize(configuration.getOauth2GatewayCookieConfiguration());
}
private LbOAuthManager getOAuthManager(HaGatewayConfiguration configuration)
diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayCookie.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayCookie.java
new file mode 100644
index 000000000..3e86f2224
--- /dev/null
+++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayCookie.java
@@ -0,0 +1,289 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.gateway.ha.router;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyOrder;
+import com.google.common.hash.Hashing;
+import io.airlift.json.JsonCodec;
+import io.airlift.log.Logger;
+import io.airlift.units.Duration;
+import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider;
+import jakarta.servlet.http.Cookie;
+
+import java.util.Base64;
+import java.util.List;
+
+import static com.google.common.base.Strings.isNullOrEmpty;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.Objects.requireNonNull;
+
+@JsonPropertyOrder(alphabetic = true)
+public class GatewayCookie
+ implements Comparable
+{
+ private static final Logger log = Logger.get(GatewayCookie.class);
+ private String signature;
+ private final UnsignedGatewayCookie unsignedGatewayCookie;
+ private final GatewayCookieConfigurationPropertiesProvider gatewayCookieConfigurationPropertiesProvider = GatewayCookieConfigurationPropertiesProvider.getInstance();
+
+ public static final String PREFIX = "TG.";
+
+ public static final JsonCodec CODEC = JsonCodec.jsonCodec(GatewayCookie.class);
+
+ @JsonCreator
+ public GatewayCookie(
+ @JsonProperty("ts") Long ts,
+ @JsonProperty("name") String name,
+ @JsonProperty("payload") String payload,
+ @JsonProperty("routingPaths") List routingPaths,
+ @JsonProperty("deletePaths") List deletePaths,
+ @JsonProperty("backend") String backend,
+ @JsonProperty("priority") Integer priority,
+ @JsonProperty("ttl") Duration ttl,
+ @JsonProperty("signature") String signature)
+ {
+ this.unsignedGatewayCookie = new UnsignedGatewayCookie(
+ requireNonNull(ts),
+ requireNonNull(name),
+ payload,
+ backend,
+ requireNonNull(routingPaths),
+ requireNonNull(deletePaths),
+ requireNonNull(ttl),
+ priority);
+ this.signature = signature;
+ }
+
+ public GatewayCookie(String name, String payload, String backend, List routingPaths, List deletePaths, Duration ttl, int priority)
+ {
+ this.unsignedGatewayCookie = new UnsignedGatewayCookie(
+ System.currentTimeMillis(),
+ requireNonNull(name.startsWith(PREFIX) ? name : PREFIX + name),
+ payload,
+ backend,
+ requireNonNull(routingPaths),
+ requireNonNull(deletePaths),
+ requireNonNull(ttl),
+ priority);
+ signature = computeSignature();
+ }
+
+ @JsonProperty
+ public Long getTs()
+ {
+ return unsignedGatewayCookie.getTs();
+ }
+
+ @JsonProperty
+ public String getName()
+ {
+ return unsignedGatewayCookie.getName();
+ }
+
+ @JsonProperty
+ public String getPayload()
+ {
+ return unsignedGatewayCookie.getPayload();
+ }
+
+ @JsonProperty
+ public String getBackend()
+ {
+ return unsignedGatewayCookie.getBackend();
+ }
+
+ @JsonProperty
+ public int getPriority()
+ {
+ return unsignedGatewayCookie.getPriority();
+ }
+
+ @JsonProperty
+ public List getRoutingPaths()
+ {
+ return unsignedGatewayCookie.getRoutingPaths();
+ }
+
+ @JsonProperty
+ public List getDeletePaths()
+ {
+ return unsignedGatewayCookie.getDeletePaths();
+ }
+
+ @JsonProperty
+ public Duration getTtl()
+ {
+ return unsignedGatewayCookie.getTtl();
+ }
+
+ @JsonProperty
+ public String getSignature()
+ {
+ return signature;
+ }
+
+ public void setTs(Long ts)
+ {
+ unsignedGatewayCookie.setTs(ts);
+ }
+
+ private String computeSignature()
+ {
+ return Hashing.hmacSha256(gatewayCookieConfigurationPropertiesProvider.getCookieSigningKey())
+ .hashString(UnsignedGatewayCookie.CODEC.toJson(unsignedGatewayCookie), UTF_8)
+ .toString();
+ }
+
+ @Override
+ public int compareTo(GatewayCookie o)
+ {
+ int priorityDelta = unsignedGatewayCookie.getPriority() - o.getPriority();
+ return priorityDelta != 0 ? priorityDelta : (int) (unsignedGatewayCookie.getTs() - o.getTs());
+ }
+
+ public Cookie toCookie()
+ {
+ Cookie cookie = new Cookie(unsignedGatewayCookie.getName(), Base64.getUrlEncoder().encodeToString(CODEC.toJson(this).getBytes(UTF_8)));
+ cookie.setMaxAge((int) unsignedGatewayCookie.getTtl().toMillis() / 1000);
+ return cookie;
+ }
+
+ public static GatewayCookie fromCookie(Cookie cookie)
+ {
+ return GatewayCookie.CODEC.fromJson(Base64.getUrlDecoder().decode(cookie.getValue()));
+ }
+
+ public boolean matchesRoutingPath(String path)
+ {
+ if (matchesDeletePath(path)) {
+ return false;
+ }
+
+ return unsignedGatewayCookie.getRoutingPaths().stream().anyMatch(path::startsWith);
+ }
+
+ public boolean matchesDeletePath(String path)
+ {
+ return unsignedGatewayCookie.getDeletePaths().contains(path);
+ }
+
+ public boolean isValid()
+ {
+ if (System.currentTimeMillis() > unsignedGatewayCookie.getTs() + unsignedGatewayCookie.getTtl().toMillis()) {
+ return false;
+ }
+
+ if (isNullOrEmpty(signature) || !signature.equals(computeSignature())) {
+ log.error("Invalid cookie: %s", CODEC.toJson(this));
+ throw new IllegalArgumentException("Invalid cookie signature");
+ }
+
+ return true;
+ }
+
+ @JsonPropertyOrder(alphabetic = true)
+ public static class UnsignedGatewayCookie
+ {
+ public static final JsonCodec CODEC = JsonCodec.jsonCodec(UnsignedGatewayCookie.class);
+ private Long ts; // timestamp. The shortened name saves 8 bytes of cookie size
+ private final String name;
+ private final String payload;
+ private final List routingPaths;
+
+ private final List deletePaths;
+
+ private final int priority;
+ private final Duration ttl;
+ private final String backend;
+
+ public UnsignedGatewayCookie(GatewayCookie gatewayCookie)
+ {
+ this.ts = gatewayCookie.getTs();
+ this.name = gatewayCookie.getName();
+ this.payload = gatewayCookie.getPayload();
+ this.routingPaths = gatewayCookie.getRoutingPaths();
+ this.deletePaths = gatewayCookie.getDeletePaths();
+ this.priority = gatewayCookie.getPriority();
+ this.ttl = gatewayCookie.getTtl();
+ this.backend = gatewayCookie.getBackend();
+ }
+
+ public UnsignedGatewayCookie(Long ts, String name, String payload, String backend, List routingPaths, List deletePaths, Duration ttl, int priority)
+ {
+ this.name = name.startsWith(PREFIX) ? name : PREFIX + name;
+ this.payload = payload;
+ this.backend = backend;
+ this.routingPaths = routingPaths;
+ this.deletePaths = deletePaths;
+ this.ttl = ttl;
+ this.ts = ts;
+ this.priority = priority;
+ }
+
+ @JsonProperty
+ public Long getTs()
+ {
+ return ts;
+ }
+
+ public void setTs(Long ts)
+ {
+ this.ts = ts;
+ }
+
+ @JsonProperty
+ public String getName()
+ {
+ return name;
+ }
+
+ @JsonProperty
+ public String getPayload()
+ {
+ return payload;
+ }
+
+ @JsonProperty
+ public String getBackend()
+ {
+ return backend;
+ }
+
+ @JsonProperty
+ public int getPriority()
+ {
+ return priority;
+ }
+
+ @JsonProperty
+ public List getRoutingPaths()
+ {
+ return routingPaths;
+ }
+
+ @JsonProperty
+ public List getDeletePaths()
+ {
+ return deletePaths;
+ }
+
+ @JsonProperty
+ public Duration getTtl()
+ {
+ return ttl;
+ }
+ }
+}
diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookie.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookie.java
new file mode 100644
index 000000000..47be10ef7
--- /dev/null
+++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookie.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.gateway.ha.router;
+
+import com.google.common.collect.ImmutableList;
+import io.trino.gateway.ha.config.OAuth2GatewayCookieConfigurationPropertiesProvider;
+
+public class OAuth2GatewayCookie
+ extends GatewayCookie
+{
+ public static final String NAME = GatewayCookie.PREFIX + "OAUTH2";
+ public static final String OAUTH2_PATH = "/oauth2";
+
+ public OAuth2GatewayCookie(String backend)
+ {
+ super(
+ NAME,
+ null,
+ backend,
+ ImmutableList.of(OAUTH2_PATH),
+ OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getDeletePaths(),
+ OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getLifetime(),
+ 0);
+ }
+}
diff --git a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyHandler.java b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyHandler.java
index 088536940..bd1a94bdd 100644
--- a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyHandler.java
+++ b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyHandler.java
@@ -13,7 +13,9 @@
*/
package io.trino.gateway.proxyserver;
+import com.google.common.collect.ImmutableList;
import io.airlift.log.Logger;
+import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import org.eclipse.jetty.client.api.Request;
@@ -27,6 +29,7 @@
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Enumeration;
+import java.util.List;
import java.util.zip.GZIPInputStream;
/* Order of control => rewriteTarget, preConnectionHook, postConnectionHook. */
@@ -114,4 +117,9 @@ protected boolean isCompressed(final byte[] compressed)
return (compressed[0] == (byte) GZIPInputStream.GZIP_MAGIC)
&& (compressed[1] == (byte) (GZIPInputStream.GZIP_MAGIC >> 8));
}
+
+ public List generateDeleteCookieList(HttpServletRequest clientRequest)
+ {
+ return ImmutableList.of();
+ }
}
diff --git a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyServletImpl.java b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyServletImpl.java
index 5267d7c89..da013d613 100644
--- a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyServletImpl.java
+++ b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyServletImpl.java
@@ -98,6 +98,16 @@ protected String rewriteTarget(HttpServletRequest request)
return target;
}
+ @Override
+ protected void onServerResponseHeaders(
+ HttpServletRequest clientRequest,
+ HttpServletResponse proxyResponse,
+ Response serverResponse)
+ {
+ this.proxyHandler.generateDeleteCookieList(clientRequest).forEach(proxyResponse::addCookie);
+ super.onServerResponseHeaders(clientRequest, proxyResponse, serverResponse);
+ }
+
/**
* Customize the response returned from remote server.
*/
diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java b/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java
index 8cc8a6a75..69f42be65 100644
--- a/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java
+++ b/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java
@@ -19,8 +19,10 @@
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
+import okhttp3.mockwebserver.Dispatcher;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
+import okhttp3.mockwebserver.RecordedRequest;
import org.jdbi.v3.core.Handle;
import org.jdbi.v3.core.Jdbi;
@@ -28,6 +30,7 @@
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
+import java.util.Map;
import java.util.Random;
import java.util.Scanner;
@@ -66,6 +69,23 @@ public static void prepareMockBackend(
.setResponseCode(200));
}
+ public static void setPathSpecificResponses(
+ MockWebServer backend, Map pathResponseMap)
+ {
+ Dispatcher dispatcher = new Dispatcher()
+ {
+ @Override
+ public MockResponse dispatch(RecordedRequest request)
+ {
+ if (pathResponseMap.containsKey(request.getPath())) {
+ return new MockResponse().setResponseCode(200).setBody(pathResponseMap.get(request.getPath()));
+ }
+ return new MockResponse().setResponseCode(404);
+ }
+ };
+ backend.setDispatcher(dispatcher);
+ }
+
public static TestConfig buildGatewayConfigAndSeedDb(int routerPort, String configFile)
throws IOException
{
diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java b/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java
index a213acec5..b123415e7 100644
--- a/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java
+++ b/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java
@@ -14,14 +14,17 @@
package io.trino.gateway.ha;
import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableMap;
import io.trino.gateway.ha.config.ProxyBackendConfiguration;
+import io.trino.gateway.ha.router.GatewayCookie;
+import io.trino.gateway.ha.router.OAuth2GatewayCookie;
+import okhttp3.Cookie;
import okhttp3.MediaType;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.mockwebserver.MockWebServer;
-import okhttp3.mockwebserver.RecordedRequest;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
@@ -29,6 +32,12 @@
import org.junit.jupiter.api.TestInstance.Lifecycle;
import org.testcontainers.containers.TrinoContainer;
+import java.io.IOException;
+import java.util.Base64;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.TimeUnit;
+
import static org.assertj.core.api.Assertions.assertThat;
@TestInstance(Lifecycle.PER_CLASS)
@@ -40,6 +49,11 @@ public class TestGatewayHaMultipleBackend
private TrinoContainer adhocTrino;
private TrinoContainer scheduledTrino;
+ public static String oauthInitiatePath = OAuth2GatewayCookie.OAUTH2_PATH;
+ public static String oauthCallbackPath = oauthInitiatePath + "/callback";
+ public static String oauthInitialResponse = "abc";
+ public static String oauthCallbackResponse = "xyz";
+
final int routerPort = 20000 + (int) (Math.random() * 1000);
final int customBackendPort = 21000 + (int) (Math.random() * 1000);
@@ -59,7 +73,11 @@ public void setup()
int backend1Port = adhocTrino.getMappedPort(8080);
int backend2Port = scheduledTrino.getMappedPort(8080);
- HaGatewayTestUtils.prepareMockBackend(customBackend, customBackendPort, CUSTOM_RESPONSE);
+ HaGatewayTestUtils.prepareMockBackend(customBackend, customBackendPort, "default custom response");
+ HaGatewayTestUtils.setPathSpecificResponses(customBackend, ImmutableMap.of(
+ oauthInitiatePath, oauthInitialResponse,
+ oauthCallbackPath, oauthCallbackResponse,
+ CUSTOM_PATH, CUSTOM_RESPONSE));
// seed database
HaGatewayTestUtils.TestConfig testConfig =
@@ -93,9 +111,6 @@ public void testCustomPath()
.build();
Response response1 = httpClient.newCall(request1).execute();
assertThat(response1.body().string()).isEqualTo(CUSTOM_RESPONSE);
- RecordedRequest recordedRequest = customBackend.takeRequest();
- assertThat(recordedRequest.getMethod()).isEqualTo("POST");
- assertThat(recordedRequest.getPath()).isEqualTo(CUSTOM_PATH);
Request request2 =
new Request.Builder()
@@ -162,6 +177,101 @@ public void testBackendConfiguration()
assertThat(backendConfiguration[2].getExternalUrl()).isEqualTo("externalUrl");
}
+ @Test
+ public void testCookieBasedRouting()
+ throws IOException
+ {
+ // This simulates the Trino oauth handshake
+ OkHttpClient httpClient = new OkHttpClient.Builder()
+ .connectTimeout(10, TimeUnit.SECONDS)
+ .writeTimeout(10, TimeUnit.SECONDS)
+ .readTimeout(60, TimeUnit.SECONDS)
+ .build();
+ String oauthInitiateBody = "anything";
+ RequestBody requestBody =
+ RequestBody.create(
+ MediaType.parse("application/json; charset=utf-8"), oauthInitiateBody);
+
+ Request initiateRequest =
+ new Request.Builder()
+ .url("http://localhost:" + routerPort + oauthInitiatePath)
+ .post(requestBody)
+ .addHeader("X-Trino-Routing-Group", "custom")
+ .build();
+ Response initiateResponse = httpClient.newCall(initiateRequest).execute();
+ assertThat(initiateResponse.header("set-cookie")).isNotEmpty();
+
+ Request callbackRequest =
+ new Request.Builder()
+ .url("http://localhost:" + routerPort + oauthCallbackPath)
+ .post(requestBody)
+ .addHeader("Cookie", initiateResponse.header("set-cookie"))
+ .build();
+ Response callbackResponse = httpClient.newCall(callbackRequest).execute();
+ assertThat(callbackResponse.body().string()).isEqualTo(oauthCallbackResponse);
+
+ Request logoutRequest =
+ new Request.Builder()
+ .url("http://localhost:" + routerPort + "/custom/logout")
+ .post(requestBody)
+ .addHeader("Cookie", initiateResponse.header("set-cookie"))
+ .build();
+ Response logoutResponse = httpClient.newCall(logoutRequest).execute();
+
+ List cookies = Cookie.parseAll(logoutResponse.request().url(), logoutResponse.headers());
+ Optional cookie = cookies.stream().filter(c -> c.name().equals(OAuth2GatewayCookie.NAME)).findAny();
+ assertThat(cookie).isNotEmpty();
+ assertThat(cookie.orElseThrow().value()).isEqualTo("delete");
+ // expires-at has been deprecated in favor of max-age. However, okhttp3 does not expose a max-age property,
+ // but instead sets expires-at to Long.MIN_VALUE when max-age is set to 0
+ // https://github.com/square/okhttp/blob/577d621585f7525d3e98a9161bc26d2965686538/okhttp/src/main/kotlin/okhttp3/Cookie.kt#L673
+ assertThat(cookie.orElseThrow().expiresAt()).isEqualTo(Long.MIN_VALUE);
+ }
+
+ @Test
+ public void testCookieSigning()
+ throws IOException
+ {
+ OkHttpClient httpClient = new OkHttpClient.Builder()
+ .connectTimeout(10, TimeUnit.SECONDS)
+ .writeTimeout(10, TimeUnit.SECONDS)
+ .readTimeout(60, TimeUnit.SECONDS)
+ .build();
+ String oauthInitiateBody = "anything";
+ RequestBody requestBody =
+ RequestBody.create(
+ MediaType.parse("application/json; charset=utf-8"), oauthInitiateBody);
+
+ Request initiateRequest =
+ new Request.Builder()
+ .url("http://localhost:" + routerPort + oauthInitiatePath)
+ .post(requestBody)
+ .addHeader("X-Trino-Routing-Group", "custom")
+ .build();
+ Response initiateResponse = httpClient.newCall(initiateRequest).execute();
+ String cookieHeader = initiateResponse.header("set-cookie");
+ assertThat(cookieHeader).isNotEmpty();
+ List cookies = Cookie.parseAll(initiateResponse.request().url(), initiateResponse.headers());
+ Optional cookie = cookies.stream().filter(c -> c.name().equals(OAuth2GatewayCookie.NAME)).findAny();
+ assertThat(cookie).isNotEmpty();
+
+ GatewayCookie gatewayCookie = GatewayCookie.CODEC.fromJson(Base64.getUrlDecoder().decode(cookie.orElseThrow().value()));
+ assertThat(gatewayCookie.getSignature()).isNotEmpty();
+
+ // Tamper with values. This will cause the cookie to be ignored because its values will not match the signature,
+ // causing the request will be routed to the adhoc backend
+ gatewayCookie.setTs(gatewayCookie.getTs() + 1000);
+ jakarta.servlet.http.Cookie tamperedCookie = gatewayCookie.toCookie();
+ Request callbackRequest =
+ new Request.Builder()
+ .url("http://localhost:" + routerPort + oauthCallbackPath)
+ .post(requestBody)
+ .addHeader("Cookie", String.format("%s=%s", tamperedCookie.getName(), tamperedCookie.getValue()))
+ .build();
+ Response callbackResponse = httpClient.newCall(callbackRequest).execute();
+ assertThat(callbackResponse.code()).isEqualTo(500);
+ }
+
@AfterAll
public void cleanup()
{
diff --git a/gateway-ha/src/test/resources/test-config-template.yml b/gateway-ha/src/test/resources/test-config-template.yml
index 8983890ed..4781e179a 100644
--- a/gateway-ha/src/test/resources/test-config-template.yml
+++ b/gateway-ha/src/test/resources/test-config-template.yml
@@ -26,5 +26,13 @@ managedApps:
extraWhitelistPaths:
- "/v1/custom"
+gatewayCookieConfiguration:
+ enabled: true
+ cookieSigningSecret: "kjlhbfrewbyuo452cds3dc1234ancdsjh"
+
+oauth2GatewayCookieConfiguration:
+ deletePaths:
+ - "/custom/logout"
+
logging:
type: external