diff --git a/docs/assets/gateway-routing-flow.svg b/docs/assets/gateway-routing-flow.svg new file mode 100644 index 000000000..07db084f8 --- /dev/null +++ b/docs/assets/gateway-routing-flow.svg @@ -0,0 +1,1665 @@ + + + + + + + + + + + + + Request + Client + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Send request to random member of routing group + + Send request to least loaded member of Routing Group + Is Adaptive Load Balancing enabled? + Select "adhoc" routing group + Select Routing Group from highest priority matched rule + Does request match a Routing Rule? + Is request related to a previous request? + Send to same backend as previous + + + + + diff --git a/docs/design.md b/docs/design.md index 91befeaa8..5cf2e492d 100644 --- a/docs/design.md +++ b/docs/design.md @@ -18,7 +18,8 @@
  • Gateway API
  • Resource groups API
  • Routing rules
  • - +
  • Routing logic
  • + diff --git a/docs/gateway-api.md b/docs/gateway-api.md index 81388107b..558a8da1e 100644 --- a/docs/gateway-api.md +++ b/docs/gateway-api.md @@ -18,6 +18,7 @@
  • Gateway API
  • Resource groups API
  • Routing rules
  • +
  • Routing logic
  • diff --git a/docs/installation.md b/docs/installation.md index 0fee2b91e..137ad453e 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -18,6 +18,7 @@
  • Gateway API
  • Resource groups API
  • Routing rules
  • +
  • Routing logic
  • diff --git a/docs/operation.md b/docs/operation.md index e0c167f9e..c107899b0 100644 --- a/docs/operation.md +++ b/docs/operation.md @@ -18,6 +18,7 @@
  • Gateway API
  • Resource groups API
  • Routing rules
  • +
  • Routing logic
  • diff --git a/docs/quickstart.md b/docs/quickstart.md index 73a97240c..d5730e443 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -18,6 +18,7 @@
  • Gateway API
  • Resource groups API
  • Routing rules
  • +
  • Routing logic
  • diff --git a/docs/references.md b/docs/references.md index 884914642..d8df9238d 100644 --- a/docs/references.md +++ b/docs/references.md @@ -18,6 +18,7 @@
  • Gateway API
  • Resource groups API
  • Routing rules
  • +
  • Routing logic
  • diff --git a/docs/release-notes.md b/docs/release-notes.md index b9c04da0d..78cc3a6a5 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -18,6 +18,7 @@
  • Gateway API
  • Resource groups API
  • Routing rules
  • +
  • Routing logic
  • diff --git a/docs/resource-groups-api.md b/docs/resource-groups-api.md index 330e93616..9e4680f8b 100644 --- a/docs/resource-groups-api.md +++ b/docs/resource-groups-api.md @@ -18,6 +18,7 @@
  • Gateway API
  • Resource groups API
  • Routing rules
  • +
  • Routing logic
  • diff --git a/docs/routing-logic.md b/docs/routing-logic.md new file mode 100644 index 000000000..a3c3ed18b --- /dev/null +++ b/docs/routing-logic.md @@ -0,0 +1,111 @@ +**Trino Gateway documentation** + + + + + + + + +
    + + + + + + + +
    + +# Routing Logic + +## Overview + +Trino Gateway checks incoming requests to see if they're related to previous +ones it handled. If they are, then Trino Gateway sends them to the same backend +that dealt with the earlier requests. + +If it is a new request, the Trino Gateway refers to [Routing rules](routing-rules.md) +to decide which group of backends, called a 'Routing Group,' should handle it. +It then picks a backend from that Routing Group to handle the request using +either an adaptive or round-robin strategy. + +![Request Routing Flow](assets/gateway-routing-flow.svg) + +## Sticky routing + +A request related to an ongoing process, or to state maintained on a single +backend cluster, must be routed to that backend for proper handling. Two +mechanisms for identifying related requests are currently implemented. By default, +only routing based on query identifier is enabled. + +### Routing based on query identifier (default) + +When a query is initiated through the Trino Gateway, the query id will be +extracted from the response and mapped to the backend that provided the +response. Any subsequent request containing that query id will be forwarded +to that backend. For example, to retrieve query results, the trino client +polls a URI of the form +`v1/statement/executing/queryid/nonce/counter`. The Trino Gateway will extract +the queryid from this URI. + +### Routing based on cookies + +OAuth2 authentication requires that the same backend is used for each step of +the handshake. When `gatewayCookieConfiguration.enabled` is set to true, a cookie +will be added to requests made to paths beginning with `/oauth2` unless they already have +a cookie present, which is used to route further `/oauth2/*` requests to the correct backend. +Cookies are not added to requests to `v1/*` and other Trino endpoints. + +Trino Gateway signs its cookies to ensure that they are not tampered with. You +must set a `cookieSigningSecret` string in your configuration +```yaml +gatewayCookieConfiguration: + enabled: true + cookieSigningSecret: "ahighentropystring" +``` +when making use of this feature. If you load balance request across multiple Trino Gateway +instances, ensure each instance has the same `cookieSigningSecret`. + +The Trino Gateway handles standard Trino OAuth2 handshakes with no additional +configuration. If you are using a customized or commercial Trino distribution, then +the paths used to define the OAuth handshake may be modified. + +`routingPaths`: If the request URI starts with a path in this list, then +* If no cookie is present, add a routing cookie +* If a cookie is present, route the request to the backend defined by that cookie + +`deletePaths`: If the request URI starts with a path in this list, +return a response that instructs the client to delete the cookie. + +Additionally, the `lifetime` property sets the duration for which a cookie remains in +effect after creation. Ensure that it is greater than +the time required to complete the handshake. Default `lifetime` is 10 minutes. + +These properties are defined under the `oauth2GatewayCookieConfiguration` node: + +```yaml +oauth2GatewayCookieConfiguration: + routingPaths: + - "/oauth2" + - "/custom/oauth2/callback" + - "/alternative/oauth2/initiate" + deletePaths: + - "/custom/logout" + lifetime: "5m" +``` diff --git a/docs/routing-rules.md b/docs/routing-rules.md index 006ec19d7..e7a3733ef 100644 --- a/docs/routing-rules.md +++ b/docs/routing-rules.md @@ -18,6 +18,7 @@
  • Gateway API
  • Resource groups API
  • Routing rules
  • +
  • Routing logic
  • diff --git a/docs/security.md b/docs/security.md index 4746d712d..147543e7a 100644 --- a/docs/security.md +++ b/docs/security.md @@ -18,6 +18,7 @@
  • Gateway API
  • Resource groups API
  • Routing rules
  • +
  • Routing logic
  • diff --git a/gateway-ha/pom.xml b/gateway-ha/pom.xml index 5f58ba9ef..4e5a0aa36 100644 --- a/gateway-ha/pom.xml +++ b/gateway-ha/pom.xml @@ -126,6 +126,12 @@ stats + + io.airlift + units + 1.10 + + io.dropwizard dropwizard-assets diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfiguration.java new file mode 100644 index 000000000..e19133596 --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfiguration.java @@ -0,0 +1,45 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.config; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; + +import static java.nio.charset.StandardCharsets.UTF_8; + +public class GatewayCookieConfiguration +{ + private SecretKey cookieSigningKey; + private boolean enabled; + + public boolean isEnabled() + { + return enabled; + } + + public void setEnabled(boolean enabled) + { + this.enabled = enabled; + } + + public SecretKey getCookieSigningKey() + { + return cookieSigningKey; + } + + public void setCookieSigningSecret(String cookieSigningSecret) + { + cookieSigningKey = new SecretKeySpec(cookieSigningSecret.getBytes(UTF_8), "HmacSHA256"); + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfigurationPropertiesProvider.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfigurationPropertiesProvider.java new file mode 100644 index 000000000..e5f0dedee --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/GatewayCookieConfigurationPropertiesProvider.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.config; + +import javax.crypto.SecretKey; + +public class GatewayCookieConfigurationPropertiesProvider +{ + private static final GatewayCookieConfigurationPropertiesProvider instance = new GatewayCookieConfigurationPropertiesProvider(); + private GatewayCookieConfiguration gatewayCookieConfiguration; + + private GatewayCookieConfigurationPropertiesProvider() + {} + + public void initialize(GatewayCookieConfiguration gatewayCookieConfiguration) + { + if (gatewayCookieConfiguration.isEnabled() && gatewayCookieConfiguration.getCookieSigningKey() == null) { + throw new IllegalArgumentException("gatewayCookieConfiguration.cookieSigningSecret must be provided when cookies are enabled"); + } + this.gatewayCookieConfiguration = gatewayCookieConfiguration; + } + + public static GatewayCookieConfigurationPropertiesProvider getInstance() + { + return instance; + } + + public boolean isEnabled() + { + ensureInitialized(); + return gatewayCookieConfiguration.isEnabled(); + } + + public SecretKey getCookieSigningKey() + { + ensureInitialized(); + return gatewayCookieConfiguration.getCookieSigningKey(); + } + + private void ensureInitialized() + { + if (gatewayCookieConfiguration == null) { + throw new IllegalStateException("getInstance.initialize(GatewayCookieConfiguration) must be called before use"); + } + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java index a0ec3d6b4..83bba6172 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/HaGatewayConfiguration.java @@ -35,6 +35,8 @@ public class HaGatewayConfiguration private BackendStateConfiguration backendState; private ClusterStatsConfiguration clusterStatsConfiguration; private List extraWhitelistPaths = new ArrayList<>(); + private OAuth2GatewayCookieConfiguration oauth2GatewayCookieConfiguration = new OAuth2GatewayCookieConfiguration(); + private GatewayCookieConfiguration gatewayCookieConfiguration = new GatewayCookieConfiguration(); // List of Modules with FQCN (Fully Qualified Class Name) private List modules; @@ -164,6 +166,26 @@ public void setExtraWhitelistPaths(List extraWhitelistPaths) this.extraWhitelistPaths = extraWhitelistPaths; } + public OAuth2GatewayCookieConfiguration getOauth2GatewayCookieConfiguration() + { + return oauth2GatewayCookieConfiguration; + } + + public void setOauth2GatewayCookieConfiguration(OAuth2GatewayCookieConfiguration oauth2GatewayCookieConfiguration) + { + this.oauth2GatewayCookieConfiguration = oauth2GatewayCookieConfiguration; + } + + public GatewayCookieConfiguration getGatewayCookieConfiguration() + { + return gatewayCookieConfiguration; + } + + public void setGatewayCookieConfiguration(GatewayCookieConfiguration gatewayCookieConfiguration) + { + this.gatewayCookieConfiguration = gatewayCookieConfiguration; + } + public List getModules() { return this.modules; diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfiguration.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfiguration.java new file mode 100644 index 000000000..d5224a16e --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfiguration.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.config; + +import com.google.common.collect.ImmutableList; +import io.airlift.units.Duration; + +import java.util.List; + +public class OAuth2GatewayCookieConfiguration +{ + // Configuration initialization using dropwizard requires + // instance method setters. Values are global, and can be accessed using static getters + private List routingPaths = ImmutableList.of("/oauth2"); + private List deletePaths = ImmutableList.of("/logout", "/oauth2/logout"); + private Duration lifetime = Duration.valueOf("10m"); + + public List getDeletePaths() + { + return deletePaths; + } + + public void setDeletePaths(List deletePaths) + { + this.deletePaths = deletePaths; + } + + public List getRoutingPaths() + { + return routingPaths; + } + + public void setRoutingPaths(List routingPaths) + { + this.routingPaths = routingPaths; + } + + public Duration getLifetime() + { + return lifetime; + } + + public void setLifetime(String lifetime) + { + this.lifetime = Duration.valueOf(lifetime); + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfigurationPropertiesProvider.java b/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfigurationPropertiesProvider.java new file mode 100644 index 000000000..dfbd14560 --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/config/OAuth2GatewayCookieConfigurationPropertiesProvider.java @@ -0,0 +1,63 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.config; + +import io.airlift.units.Duration; + +import java.util.List; + +public class OAuth2GatewayCookieConfigurationPropertiesProvider +{ + private static final OAuth2GatewayCookieConfigurationPropertiesProvider instance = new OAuth2GatewayCookieConfigurationPropertiesProvider(); + + private OAuth2GatewayCookieConfiguration oAuth2GatewayCookieConfiguration; + + private OAuth2GatewayCookieConfigurationPropertiesProvider() + {} + + public static OAuth2GatewayCookieConfigurationPropertiesProvider getInstance() + { + return instance; + } + + public void initialize(OAuth2GatewayCookieConfiguration oAuth2GatewayCookieConfiguration) + { + this.oAuth2GatewayCookieConfiguration = oAuth2GatewayCookieConfiguration; + } + + public List getDeletePaths() + { + ensureInitialized(); + return oAuth2GatewayCookieConfiguration.getDeletePaths(); + } + + public List getRoutingPaths() + { + ensureInitialized(); + return oAuth2GatewayCookieConfiguration.getRoutingPaths(); + } + + public Duration getLifetime() + { + ensureInitialized(); + return oAuth2GatewayCookieConfiguration.getLifetime(); + } + + private void ensureInitialized() + { + if (oAuth2GatewayCookieConfiguration == null) { + throw new IllegalStateException("getInstance.initialize(OAuth2GatewayCookieConfiguration) must be called before use"); + } + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java index 792496c30..4a401a35e 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java @@ -15,13 +15,18 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; import com.google.common.io.CharStreams; import io.airlift.log.Logger; +import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider; +import io.trino.gateway.ha.router.GatewayCookie; +import io.trino.gateway.ha.router.OAuth2GatewayCookie; import io.trino.gateway.ha.router.QueryHistoryManager; import io.trino.gateway.ha.router.RoutingGroupSelector; import io.trino.gateway.ha.router.RoutingManager; import io.trino.gateway.proxyserver.ProxyHandler; import io.trino.gateway.proxyserver.wrapper.MultiReadHttpServletRequest; +import jakarta.servlet.http.Cookie; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import jakarta.ws.rs.HttpMethod; @@ -32,9 +37,11 @@ import java.io.IOException; import java.net.URI; import java.net.URISyntaxException; +import java.util.Arrays; import java.util.Base64; import java.util.HashMap; import java.util.List; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -84,6 +91,7 @@ public class QueryIdCachingProxyHandler private final ProxyHandlerStats proxyHandlerStats; private final List extraWhitelistPaths; private final String applicationEndpoint; + private final boolean cookiesEnabled; public QueryIdCachingProxyHandler( QueryHistoryManager queryHistoryManager, @@ -99,6 +107,7 @@ public QueryIdCachingProxyHandler( this.queryHistoryManager = queryHistoryManager; this.extraWhitelistPaths = extraWhitelistPaths; this.applicationEndpoint = "http://localhost:" + serverApplicationPort; + cookiesEnabled = GatewayCookieConfigurationPropertiesProvider.getInstance().isEnabled(); } protected static String extractQueryIdIfPresent(String path, String queryParams) @@ -269,9 +278,23 @@ private boolean isPathWhiteListed(String path) || extraWhitelistPaths.stream().anyMatch(s -> path.startsWith(s)); } - public boolean handleAuthRequest(HttpServletRequest request) + @Override + public List generateDeleteCookieList(HttpServletRequest clientRequest) { - return true; + if (!cookiesEnabled || clientRequest.getCookies() == null) { + return ImmutableList.of(); + } + + return Arrays.stream(clientRequest.getCookies()) + .filter(c -> c.getName().startsWith(GatewayCookie.PREFIX)) + .map(GatewayCookie::fromCookie) + .filter(c -> !c.isValid() || c.matchesDeletePath(clientRequest.getRequestURI())) + .map(GatewayCookie::toCookie) + .peek(c -> { + c.setValue("delete"); + c.setMaxAge(0); + }) + .toList(); } @Override @@ -281,10 +304,10 @@ public String rewriteTarget(HttpServletRequest request) return buildUriWithNewBackend(applicationEndpoint, request); } - if (isUseStickyRouting(request)) { - String backend = getPreviousBackend(request); - logRewrite(backend, request); - return buildUriWithNewBackend(backend, request); + Optional previousBackend = getPreviousBackend(request); + if (previousBackend.isPresent()) { + logRewrite(previousBackend.orElseThrow(), request); + return previousBackend.map(b -> buildUriWithNewBackend(b, request)).orElseThrow(); } String backend = getBackendFromRoutingGroup(request); @@ -322,21 +345,27 @@ private String getBackendFromRoutingGroup(HttpServletRequest request) return routingManager.provideAdhocBackend(user); } - private String getPreviousBackend(HttpServletRequest request) + private Optional getPreviousBackend(HttpServletRequest request) { String queryId = extractQueryIdIfPresent(request); - - // Find query id and get url from cache if (!isNullOrEmpty(queryId)) { - return routingManager.findBackendForQueryId(queryId); + return Optional.of(routingManager.findBackendForQueryId(queryId)); + } + if (cookiesEnabled && request.getCookies() != null) { + List cookies = Arrays.stream(request.getCookies()) + .filter(c -> c.getName().startsWith(GatewayCookie.PREFIX)) + .map(GatewayCookie::fromCookie) + .filter(GatewayCookie::isValid) + .filter(c -> !isNullOrEmpty(c.getBackend())) + .filter(c -> c.matchesRoutingPath(request.getRequestURI())) + .sorted() + .toList(); + if (!cookies.isEmpty()) { + return Optional.of(cookies.getFirst().getBackend()); + } } - log.error("No backend found for queryId %s", queryId); - return getBackendFromRoutingGroup(request); - } - private boolean isUseStickyRouting(HttpServletRequest request) - { - return !isNullOrEmpty(extractQueryIdIfPresent(request)); + return Optional.empty(); } @Override @@ -349,50 +378,14 @@ protected void postConnectionHook( Callback callback) { try { - String requestPath = request.getRequestURI(); - if (requestPath.startsWith(V1_STATEMENT_PATH) - && request.getMethod().equals(HttpMethod.POST)) { - String output; - boolean isGZipEncoding = isGZipEncoding(response); - if (isGZipEncoding) { - output = plainTextFromGz(buffer); - } - else { - output = new String(buffer); - } - log.debug("For Request [%s] got Response output [%s]", request.getRequestURI(), output); - - QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request); - log.debug("Extracting Proxy destination : [%s] for request : [%s]", - queryDetail.getBackendUrl(), request.getRequestURI()); - - if (response.getStatus() == HttpStatus.OK_200) { - HashMap results = OBJECT_MAPPER.readValue(output, HashMap.class); - queryDetail.setQueryId(results.get("id")); - - if (!isNullOrEmpty(queryDetail.getQueryId())) { - routingManager.setBackendForQueryId( - queryDetail.getQueryId(), queryDetail.getBackendUrl()); - log.debug( - "QueryId [%s] mapped with proxy [%s]", - queryDetail.getQueryId(), - queryDetail.getBackendUrl()); - } - else { - log.debug("QueryId [%s] could not be cached", queryDetail.getQueryId()); - } - } - else { - log.error( - "Non OK HTTP Status code with response [%s] , Status code [%s]", - output, - response.getStatus()); - } - // Saving history at gateway. - queryHistoryManager.submitQueryDetail(queryDetail); + if (request.getRequestURI().startsWith(V1_STATEMENT_PATH) && request.getMethod().equals(HttpMethod.POST)) { + recordBackendForQueryId(request, response, buffer); } - else { - log.debug("SKIPPING For %s", requestPath); + else if (cookiesEnabled && request.getRequestURI().startsWith(OAuth2GatewayCookie.OAUTH2_PATH) + && !(request.getCookies() != null + && Arrays.stream(request.getCookies()).anyMatch(c -> c.getName().equals(OAuth2GatewayCookie.NAME)))) { + GatewayCookie oauth2Cookie = new OAuth2GatewayCookie(request.getHeader(PROXY_TARGET_HEADER)); + response.addCookie(oauth2Cookie.toCookie()); } } catch (Exception e) { @@ -401,6 +394,42 @@ protected void postConnectionHook( super.postConnectionHook(request, response, buffer, offset, length, callback); } + void recordBackendForQueryId(HttpServletRequest request, HttpServletResponse response, byte[] buffer) + throws IOException + { + String output; + boolean isGZipEncoding = isGZipEncoding(response); + if (isGZipEncoding) { + output = plainTextFromGz(buffer); + } + else { + output = new String(buffer); + } + log.debug("For Request [%s] got Response output [%s]", request.getRequestURI(), output); + + QueryHistoryManager.QueryDetail queryDetail = getQueryDetailsFromRequest(request); + + if (queryDetail.getBackendUrl() == null) { + log.error("Server response to request %s does not contain proxytarget header", request.getRequestURI()); + } + log.debug("Extracting Proxy destination : [%s] for request : [%s]", queryDetail.getBackendUrl(), request.getRequestURI()); + + if (response.getStatus() == HttpStatus.OK_200) { + HashMap results = OBJECT_MAPPER.readValue(output, HashMap.class); + queryDetail.setQueryId(results.get("id")); + + if (!isNullOrEmpty(queryDetail.getQueryId())) { + routingManager.setBackendForQueryId(queryDetail.getQueryId(), queryDetail.getBackendUrl()); + log.debug("QueryId [%s] mapped with proxy [%s]", queryDetail.getQueryId(), queryDetail.getBackendUrl()); + } + } + else { + log.error("Non OK HTTP Status code with response [%s] , Status code [%s]", output, response.getStatus()); + } + // Save history in Trino Gateway. + queryHistoryManager.submitQueryDetail(queryDetail); + } + private QueryHistoryManager.QueryDetail getQueryDetailsFromRequest(HttpServletRequest request) throws IOException { diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java index df8a5239d..8ee294d5a 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/module/HaGatewayProviderModule.java @@ -27,7 +27,9 @@ import io.dropwizard.jetty.HttpConnectorFactory; import io.trino.gateway.ha.config.AuthenticationConfiguration; import io.trino.gateway.ha.config.AuthorizationConfiguration; +import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider; import io.trino.gateway.ha.config.HaGatewayConfiguration; +import io.trino.gateway.ha.config.OAuth2GatewayCookieConfigurationPropertiesProvider; import io.trino.gateway.ha.config.RequestRouterConfiguration; import io.trino.gateway.ha.config.RoutingRulesConfiguration; import io.trino.gateway.ha.config.UserConfiguration; @@ -86,6 +88,12 @@ public HaGatewayProviderModule(HaGatewayConfiguration configuration, Environment authenticationFilter = getAuthFilter(configuration); backendStateConnectionManager = new BackendStateManager(); extraWhitelistPaths = configuration.getExtraWhitelistPaths(); + + GatewayCookieConfigurationPropertiesProvider gatewayCookieConfigurationPropertiesProvider = GatewayCookieConfigurationPropertiesProvider.getInstance(); + gatewayCookieConfigurationPropertiesProvider.initialize(configuration.getGatewayCookieConfiguration()); + + OAuth2GatewayCookieConfigurationPropertiesProvider oAuth2GatewayCookieConfigurationPropertiesProvider = OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance(); + oAuth2GatewayCookieConfigurationPropertiesProvider.initialize(configuration.getOauth2GatewayCookieConfiguration()); } private LbOAuthManager getOAuthManager(HaGatewayConfiguration configuration) diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayCookie.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayCookie.java new file mode 100644 index 000000000..3e86f2224 --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/GatewayCookie.java @@ -0,0 +1,289 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.router; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyOrder; +import com.google.common.hash.Hashing; +import io.airlift.json.JsonCodec; +import io.airlift.log.Logger; +import io.airlift.units.Duration; +import io.trino.gateway.ha.config.GatewayCookieConfigurationPropertiesProvider; +import jakarta.servlet.http.Cookie; + +import java.util.Base64; +import java.util.List; + +import static com.google.common.base.Strings.isNullOrEmpty; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; + +@JsonPropertyOrder(alphabetic = true) +public class GatewayCookie + implements Comparable +{ + private static final Logger log = Logger.get(GatewayCookie.class); + private String signature; + private final UnsignedGatewayCookie unsignedGatewayCookie; + private final GatewayCookieConfigurationPropertiesProvider gatewayCookieConfigurationPropertiesProvider = GatewayCookieConfigurationPropertiesProvider.getInstance(); + + public static final String PREFIX = "TG."; + + public static final JsonCodec CODEC = JsonCodec.jsonCodec(GatewayCookie.class); + + @JsonCreator + public GatewayCookie( + @JsonProperty("ts") Long ts, + @JsonProperty("name") String name, + @JsonProperty("payload") String payload, + @JsonProperty("routingPaths") List routingPaths, + @JsonProperty("deletePaths") List deletePaths, + @JsonProperty("backend") String backend, + @JsonProperty("priority") Integer priority, + @JsonProperty("ttl") Duration ttl, + @JsonProperty("signature") String signature) + { + this.unsignedGatewayCookie = new UnsignedGatewayCookie( + requireNonNull(ts), + requireNonNull(name), + payload, + backend, + requireNonNull(routingPaths), + requireNonNull(deletePaths), + requireNonNull(ttl), + priority); + this.signature = signature; + } + + public GatewayCookie(String name, String payload, String backend, List routingPaths, List deletePaths, Duration ttl, int priority) + { + this.unsignedGatewayCookie = new UnsignedGatewayCookie( + System.currentTimeMillis(), + requireNonNull(name.startsWith(PREFIX) ? name : PREFIX + name), + payload, + backend, + requireNonNull(routingPaths), + requireNonNull(deletePaths), + requireNonNull(ttl), + priority); + signature = computeSignature(); + } + + @JsonProperty + public Long getTs() + { + return unsignedGatewayCookie.getTs(); + } + + @JsonProperty + public String getName() + { + return unsignedGatewayCookie.getName(); + } + + @JsonProperty + public String getPayload() + { + return unsignedGatewayCookie.getPayload(); + } + + @JsonProperty + public String getBackend() + { + return unsignedGatewayCookie.getBackend(); + } + + @JsonProperty + public int getPriority() + { + return unsignedGatewayCookie.getPriority(); + } + + @JsonProperty + public List getRoutingPaths() + { + return unsignedGatewayCookie.getRoutingPaths(); + } + + @JsonProperty + public List getDeletePaths() + { + return unsignedGatewayCookie.getDeletePaths(); + } + + @JsonProperty + public Duration getTtl() + { + return unsignedGatewayCookie.getTtl(); + } + + @JsonProperty + public String getSignature() + { + return signature; + } + + public void setTs(Long ts) + { + unsignedGatewayCookie.setTs(ts); + } + + private String computeSignature() + { + return Hashing.hmacSha256(gatewayCookieConfigurationPropertiesProvider.getCookieSigningKey()) + .hashString(UnsignedGatewayCookie.CODEC.toJson(unsignedGatewayCookie), UTF_8) + .toString(); + } + + @Override + public int compareTo(GatewayCookie o) + { + int priorityDelta = unsignedGatewayCookie.getPriority() - o.getPriority(); + return priorityDelta != 0 ? priorityDelta : (int) (unsignedGatewayCookie.getTs() - o.getTs()); + } + + public Cookie toCookie() + { + Cookie cookie = new Cookie(unsignedGatewayCookie.getName(), Base64.getUrlEncoder().encodeToString(CODEC.toJson(this).getBytes(UTF_8))); + cookie.setMaxAge((int) unsignedGatewayCookie.getTtl().toMillis() / 1000); + return cookie; + } + + public static GatewayCookie fromCookie(Cookie cookie) + { + return GatewayCookie.CODEC.fromJson(Base64.getUrlDecoder().decode(cookie.getValue())); + } + + public boolean matchesRoutingPath(String path) + { + if (matchesDeletePath(path)) { + return false; + } + + return unsignedGatewayCookie.getRoutingPaths().stream().anyMatch(path::startsWith); + } + + public boolean matchesDeletePath(String path) + { + return unsignedGatewayCookie.getDeletePaths().contains(path); + } + + public boolean isValid() + { + if (System.currentTimeMillis() > unsignedGatewayCookie.getTs() + unsignedGatewayCookie.getTtl().toMillis()) { + return false; + } + + if (isNullOrEmpty(signature) || !signature.equals(computeSignature())) { + log.error("Invalid cookie: %s", CODEC.toJson(this)); + throw new IllegalArgumentException("Invalid cookie signature"); + } + + return true; + } + + @JsonPropertyOrder(alphabetic = true) + public static class UnsignedGatewayCookie + { + public static final JsonCodec CODEC = JsonCodec.jsonCodec(UnsignedGatewayCookie.class); + private Long ts; // timestamp. The shortened name saves 8 bytes of cookie size + private final String name; + private final String payload; + private final List routingPaths; + + private final List deletePaths; + + private final int priority; + private final Duration ttl; + private final String backend; + + public UnsignedGatewayCookie(GatewayCookie gatewayCookie) + { + this.ts = gatewayCookie.getTs(); + this.name = gatewayCookie.getName(); + this.payload = gatewayCookie.getPayload(); + this.routingPaths = gatewayCookie.getRoutingPaths(); + this.deletePaths = gatewayCookie.getDeletePaths(); + this.priority = gatewayCookie.getPriority(); + this.ttl = gatewayCookie.getTtl(); + this.backend = gatewayCookie.getBackend(); + } + + public UnsignedGatewayCookie(Long ts, String name, String payload, String backend, List routingPaths, List deletePaths, Duration ttl, int priority) + { + this.name = name.startsWith(PREFIX) ? name : PREFIX + name; + this.payload = payload; + this.backend = backend; + this.routingPaths = routingPaths; + this.deletePaths = deletePaths; + this.ttl = ttl; + this.ts = ts; + this.priority = priority; + } + + @JsonProperty + public Long getTs() + { + return ts; + } + + public void setTs(Long ts) + { + this.ts = ts; + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public String getPayload() + { + return payload; + } + + @JsonProperty + public String getBackend() + { + return backend; + } + + @JsonProperty + public int getPriority() + { + return priority; + } + + @JsonProperty + public List getRoutingPaths() + { + return routingPaths; + } + + @JsonProperty + public List getDeletePaths() + { + return deletePaths; + } + + @JsonProperty + public Duration getTtl() + { + return ttl; + } + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookie.java b/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookie.java new file mode 100644 index 000000000..47be10ef7 --- /dev/null +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/router/OAuth2GatewayCookie.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.gateway.ha.router; + +import com.google.common.collect.ImmutableList; +import io.trino.gateway.ha.config.OAuth2GatewayCookieConfigurationPropertiesProvider; + +public class OAuth2GatewayCookie + extends GatewayCookie +{ + public static final String NAME = GatewayCookie.PREFIX + "OAUTH2"; + public static final String OAUTH2_PATH = "/oauth2"; + + public OAuth2GatewayCookie(String backend) + { + super( + NAME, + null, + backend, + ImmutableList.of(OAUTH2_PATH), + OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getDeletePaths(), + OAuth2GatewayCookieConfigurationPropertiesProvider.getInstance().getLifetime(), + 0); + } +} diff --git a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyHandler.java b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyHandler.java index 088536940..bd1a94bdd 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyHandler.java @@ -13,7 +13,9 @@ */ package io.trino.gateway.proxyserver; +import com.google.common.collect.ImmutableList; import io.airlift.log.Logger; +import jakarta.servlet.http.Cookie; import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletResponse; import org.eclipse.jetty.client.api.Request; @@ -27,6 +29,7 @@ import java.nio.charset.Charset; import java.util.Arrays; import java.util.Enumeration; +import java.util.List; import java.util.zip.GZIPInputStream; /* Order of control => rewriteTarget, preConnectionHook, postConnectionHook. */ @@ -114,4 +117,9 @@ protected boolean isCompressed(final byte[] compressed) return (compressed[0] == (byte) GZIPInputStream.GZIP_MAGIC) && (compressed[1] == (byte) (GZIPInputStream.GZIP_MAGIC >> 8)); } + + public List generateDeleteCookieList(HttpServletRequest clientRequest) + { + return ImmutableList.of(); + } } diff --git a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyServletImpl.java b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyServletImpl.java index 5267d7c89..da013d613 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyServletImpl.java +++ b/gateway-ha/src/main/java/io/trino/gateway/proxyserver/ProxyServletImpl.java @@ -98,6 +98,16 @@ protected String rewriteTarget(HttpServletRequest request) return target; } + @Override + protected void onServerResponseHeaders( + HttpServletRequest clientRequest, + HttpServletResponse proxyResponse, + Response serverResponse) + { + this.proxyHandler.generateDeleteCookieList(clientRequest).forEach(proxyResponse::addCookie); + super.onServerResponseHeaders(clientRequest, proxyResponse, serverResponse); + } + /** * Customize the response returned from remote server. */ diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java b/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java index 8cc8a6a75..69f42be65 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/HaGatewayTestUtils.java @@ -19,8 +19,10 @@ import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; +import okhttp3.mockwebserver.Dispatcher; import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; import org.jdbi.v3.core.Handle; import org.jdbi.v3.core.Jdbi; @@ -28,6 +30,7 @@ import java.io.FileWriter; import java.io.IOException; import java.io.InputStream; +import java.util.Map; import java.util.Random; import java.util.Scanner; @@ -66,6 +69,23 @@ public static void prepareMockBackend( .setResponseCode(200)); } + public static void setPathSpecificResponses( + MockWebServer backend, Map pathResponseMap) + { + Dispatcher dispatcher = new Dispatcher() + { + @Override + public MockResponse dispatch(RecordedRequest request) + { + if (pathResponseMap.containsKey(request.getPath())) { + return new MockResponse().setResponseCode(200).setBody(pathResponseMap.get(request.getPath())); + } + return new MockResponse().setResponseCode(404); + } + }; + backend.setDispatcher(dispatcher); + } + public static TestConfig buildGatewayConfigAndSeedDb(int routerPort, String configFile) throws IOException { diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java b/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java index a213acec5..b123415e7 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/TestGatewayHaMultipleBackend.java @@ -14,14 +14,17 @@ package io.trino.gateway.ha; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.collect.ImmutableMap; import io.trino.gateway.ha.config.ProxyBackendConfiguration; +import io.trino.gateway.ha.router.GatewayCookie; +import io.trino.gateway.ha.router.OAuth2GatewayCookie; +import okhttp3.Cookie; import okhttp3.MediaType; import okhttp3.OkHttpClient; import okhttp3.Request; import okhttp3.RequestBody; import okhttp3.Response; import okhttp3.mockwebserver.MockWebServer; -import okhttp3.mockwebserver.RecordedRequest; import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; @@ -29,6 +32,12 @@ import org.junit.jupiter.api.TestInstance.Lifecycle; import org.testcontainers.containers.TrinoContainer; +import java.io.IOException; +import java.util.Base64; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + import static org.assertj.core.api.Assertions.assertThat; @TestInstance(Lifecycle.PER_CLASS) @@ -40,6 +49,11 @@ public class TestGatewayHaMultipleBackend private TrinoContainer adhocTrino; private TrinoContainer scheduledTrino; + public static String oauthInitiatePath = OAuth2GatewayCookie.OAUTH2_PATH; + public static String oauthCallbackPath = oauthInitiatePath + "/callback"; + public static String oauthInitialResponse = "abc"; + public static String oauthCallbackResponse = "xyz"; + final int routerPort = 20000 + (int) (Math.random() * 1000); final int customBackendPort = 21000 + (int) (Math.random() * 1000); @@ -59,7 +73,11 @@ public void setup() int backend1Port = adhocTrino.getMappedPort(8080); int backend2Port = scheduledTrino.getMappedPort(8080); - HaGatewayTestUtils.prepareMockBackend(customBackend, customBackendPort, CUSTOM_RESPONSE); + HaGatewayTestUtils.prepareMockBackend(customBackend, customBackendPort, "default custom response"); + HaGatewayTestUtils.setPathSpecificResponses(customBackend, ImmutableMap.of( + oauthInitiatePath, oauthInitialResponse, + oauthCallbackPath, oauthCallbackResponse, + CUSTOM_PATH, CUSTOM_RESPONSE)); // seed database HaGatewayTestUtils.TestConfig testConfig = @@ -93,9 +111,6 @@ public void testCustomPath() .build(); Response response1 = httpClient.newCall(request1).execute(); assertThat(response1.body().string()).isEqualTo(CUSTOM_RESPONSE); - RecordedRequest recordedRequest = customBackend.takeRequest(); - assertThat(recordedRequest.getMethod()).isEqualTo("POST"); - assertThat(recordedRequest.getPath()).isEqualTo(CUSTOM_PATH); Request request2 = new Request.Builder() @@ -162,6 +177,101 @@ public void testBackendConfiguration() assertThat(backendConfiguration[2].getExternalUrl()).isEqualTo("externalUrl"); } + @Test + public void testCookieBasedRouting() + throws IOException + { + // This simulates the Trino oauth handshake + OkHttpClient httpClient = new OkHttpClient.Builder() + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(10, TimeUnit.SECONDS) + .readTimeout(60, TimeUnit.SECONDS) + .build(); + String oauthInitiateBody = "anything"; + RequestBody requestBody = + RequestBody.create( + MediaType.parse("application/json; charset=utf-8"), oauthInitiateBody); + + Request initiateRequest = + new Request.Builder() + .url("http://localhost:" + routerPort + oauthInitiatePath) + .post(requestBody) + .addHeader("X-Trino-Routing-Group", "custom") + .build(); + Response initiateResponse = httpClient.newCall(initiateRequest).execute(); + assertThat(initiateResponse.header("set-cookie")).isNotEmpty(); + + Request callbackRequest = + new Request.Builder() + .url("http://localhost:" + routerPort + oauthCallbackPath) + .post(requestBody) + .addHeader("Cookie", initiateResponse.header("set-cookie")) + .build(); + Response callbackResponse = httpClient.newCall(callbackRequest).execute(); + assertThat(callbackResponse.body().string()).isEqualTo(oauthCallbackResponse); + + Request logoutRequest = + new Request.Builder() + .url("http://localhost:" + routerPort + "/custom/logout") + .post(requestBody) + .addHeader("Cookie", initiateResponse.header("set-cookie")) + .build(); + Response logoutResponse = httpClient.newCall(logoutRequest).execute(); + + List cookies = Cookie.parseAll(logoutResponse.request().url(), logoutResponse.headers()); + Optional cookie = cookies.stream().filter(c -> c.name().equals(OAuth2GatewayCookie.NAME)).findAny(); + assertThat(cookie).isNotEmpty(); + assertThat(cookie.orElseThrow().value()).isEqualTo("delete"); + // expires-at has been deprecated in favor of max-age. However, okhttp3 does not expose a max-age property, + // but instead sets expires-at to Long.MIN_VALUE when max-age is set to 0 + // https://github.com/square/okhttp/blob/577d621585f7525d3e98a9161bc26d2965686538/okhttp/src/main/kotlin/okhttp3/Cookie.kt#L673 + assertThat(cookie.orElseThrow().expiresAt()).isEqualTo(Long.MIN_VALUE); + } + + @Test + public void testCookieSigning() + throws IOException + { + OkHttpClient httpClient = new OkHttpClient.Builder() + .connectTimeout(10, TimeUnit.SECONDS) + .writeTimeout(10, TimeUnit.SECONDS) + .readTimeout(60, TimeUnit.SECONDS) + .build(); + String oauthInitiateBody = "anything"; + RequestBody requestBody = + RequestBody.create( + MediaType.parse("application/json; charset=utf-8"), oauthInitiateBody); + + Request initiateRequest = + new Request.Builder() + .url("http://localhost:" + routerPort + oauthInitiatePath) + .post(requestBody) + .addHeader("X-Trino-Routing-Group", "custom") + .build(); + Response initiateResponse = httpClient.newCall(initiateRequest).execute(); + String cookieHeader = initiateResponse.header("set-cookie"); + assertThat(cookieHeader).isNotEmpty(); + List cookies = Cookie.parseAll(initiateResponse.request().url(), initiateResponse.headers()); + Optional cookie = cookies.stream().filter(c -> c.name().equals(OAuth2GatewayCookie.NAME)).findAny(); + assertThat(cookie).isNotEmpty(); + + GatewayCookie gatewayCookie = GatewayCookie.CODEC.fromJson(Base64.getUrlDecoder().decode(cookie.orElseThrow().value())); + assertThat(gatewayCookie.getSignature()).isNotEmpty(); + + // Tamper with values. This will cause the cookie to be ignored because its values will not match the signature, + // causing the request will be routed to the adhoc backend + gatewayCookie.setTs(gatewayCookie.getTs() + 1000); + jakarta.servlet.http.Cookie tamperedCookie = gatewayCookie.toCookie(); + Request callbackRequest = + new Request.Builder() + .url("http://localhost:" + routerPort + oauthCallbackPath) + .post(requestBody) + .addHeader("Cookie", String.format("%s=%s", tamperedCookie.getName(), tamperedCookie.getValue())) + .build(); + Response callbackResponse = httpClient.newCall(callbackRequest).execute(); + assertThat(callbackResponse.code()).isEqualTo(500); + } + @AfterAll public void cleanup() { diff --git a/gateway-ha/src/test/resources/test-config-template.yml b/gateway-ha/src/test/resources/test-config-template.yml index 8983890ed..4781e179a 100644 --- a/gateway-ha/src/test/resources/test-config-template.yml +++ b/gateway-ha/src/test/resources/test-config-template.yml @@ -26,5 +26,13 @@ managedApps: extraWhitelistPaths: - "/v1/custom" +gatewayCookieConfiguration: + enabled: true + cookieSigningSecret: "kjlhbfrewbyuo452cds3dc1234ancdsjh" + +oauth2GatewayCookieConfiguration: + deletePaths: + - "/custom/logout" + logging: type: external