From 0b68c1f6625ff9424aca1aa9b5c3ac793dbc8974 Mon Sep 17 00:00:00 2001 From: Will Morrison Date: Fri, 2 Feb 2024 12:45:49 -0500 Subject: [PATCH] Refactor rewriteTarget --- .../handler/QueryIdCachingProxyHandler.java | 122 ++++++++++-------- 1 file changed, 65 insertions(+), 57 deletions(-) diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java index e37187902..32671bb41 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/QueryIdCachingProxyHandler.java @@ -16,7 +16,6 @@ import com.codahale.metrics.Meter; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.base.Splitter; -import com.google.common.base.Strings; import com.google.common.io.CharStreams; import io.trino.gateway.ha.router.QueryHistoryManager; import io.trino.gateway.ha.router.RoutingGroupSelector; @@ -41,6 +40,8 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +import static com.google.common.base.Strings.isNullOrEmpty; + public class QueryIdCachingProxyHandler extends ProxyHandler { @@ -71,8 +72,8 @@ public class QueryIdCachingProxyHandler private final QueryHistoryManager queryHistoryManager; private final Meter requestMeter; - private final int serverApplicationPort; private final List extraWhitelistPaths; + private final String applicationEndpoint; public QueryIdCachingProxyHandler( QueryHistoryManager queryHistoryManager, @@ -85,8 +86,8 @@ public QueryIdCachingProxyHandler( this.routingManager = routingManager; this.routingGroupSelector = routingGroupSelector; this.queryHistoryManager = queryHistoryManager; - this.serverApplicationPort = serverApplicationPort; this.extraWhitelistPaths = extraWhitelistPaths; + this.applicationEndpoint = "http://localhost:" + serverApplicationPort; } protected static String extractQueryIdIfPresent(String path, String queryParams) @@ -151,7 +152,7 @@ static String getQueryUser(HttpServletRequest request) { String trinoUser = request.getHeader(USER_HEADER); - if (!Strings.isNullOrEmpty(trinoUser)) { + if (!isNullOrEmpty(trinoUser)) { log.info("user from {}", USER_HEADER); return trinoUser; } @@ -171,7 +172,7 @@ static String getQueryUser(HttpServletRequest request) } String headerInfo = header.substring(space + 1).trim(); - if (Strings.isNullOrEmpty(headerInfo)) { + if (isNullOrEmpty(headerInfo)) { log.error("The encoded value of basic auth doesn't exist"); return user; } @@ -191,7 +192,7 @@ protected String extractQueryIdIfPresent(HttpServletRequest request) String queryParams = request.getQueryString(); try { String queryText = CharStreams.toString(request.getReader()); - if (!Strings.isNullOrEmpty(queryText) + if (!isNullOrEmpty(queryText) && queryText.toLowerCase().contains("system.runtime.kill_query")) { // extract and return the queryId String[] parts = queryText.split(","); @@ -200,7 +201,7 @@ protected String extractQueryIdIfPresent(HttpServletRequest request) Matcher m = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part); if (m.find()) { String queryQuoted = m.group(); - if (!Strings.isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) { + if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) { return queryQuoted.substring(1, queryQuoted.length() - 1); } } @@ -251,11 +252,6 @@ private boolean isPathWhiteListed(String path) || extraWhitelistPaths.stream().anyMatch(s -> path.startsWith(s)); } - public boolean isAuthEnabled() - { - return false; - } - public boolean handleAuthRequest(HttpServletRequest request) { return true; @@ -264,54 +260,66 @@ public boolean handleAuthRequest(HttpServletRequest request) @Override public String rewriteTarget(HttpServletRequest request) { - /* Here comes the load balancer / gateway */ - String backendAddress = "http://localhost:" + serverApplicationPort; + if (!isPathWhiteListed(request.getRequestURI())) { + return buildUriWithNewBackend(applicationEndpoint, request); + } - // Only load balance trino query APIs. - if (isPathWhiteListed(request.getRequestURI())) { - String queryId = extractQueryIdIfPresent(request); + if (isUseStickyRouting(request)) { + String backend = getPreviousBackend(request); + logRewrite(backend, request); + return buildUriWithNewBackend(backend, request); + } - // Find query id and get url from cache - if (!Strings.isNullOrEmpty(queryId)) { - backendAddress = routingManager.findBackendForQueryId(queryId); - } - else { - String routingGroup = routingGroupSelector.findRoutingGroup(request); - String user = request.getHeader(USER_HEADER); - if (!Strings.isNullOrEmpty(routingGroup)) { - // This falls back on adhoc backend if there are no cluster found for the routing group. - backendAddress = routingManager.provideBackendForRoutingGroup(routingGroup, user); - } - else { - backendAddress = routingManager.provideAdhocBackend(user); - } - } - // set target backend so that we could save queryId to backend mapping later. - ((MultiReadHttpServletRequest) request).addHeader(PROXY_TARGET_HEADER, backendAddress); + String backend = getBackendFromRoutingGroup(request); + // set target backend so that we could save queryId to backend mapping later. + ((MultiReadHttpServletRequest) request).addHeader(PROXY_TARGET_HEADER, backend); + logRewrite(backend, request); + + return buildUriWithNewBackend(backend, request); + } + + private void logRewrite(String newBackend, HttpServletRequest request) + { + log.info("Rerouting [{}://{}:{}{}{}]--> [{}]", + request.getScheme(), + request.getRemoteHost(), + request.getServerPort(), + request.getRequestURI(), + (request.getQueryString() != null ? "?" + request.getQueryString() : ""), + buildUriWithNewBackend(newBackend, request)); + } + + private String buildUriWithNewBackend(String backendHost, HttpServletRequest request) + { + return backendHost + request.getRequestURI() + (request.getQueryString() != null ? "?" + request.getQueryString() : ""); + } + + private String getBackendFromRoutingGroup(HttpServletRequest request) + { + String routingGroup = routingGroupSelector.findRoutingGroup(request); + String user = request.getHeader(USER_HEADER); + if (!isNullOrEmpty(routingGroup)) { + // This falls back on adhoc backend if there are no cluster found for the routing group. + return routingManager.provideBackendForRoutingGroup(routingGroup, user); } - if (isAuthEnabled() && request.getHeader("Authorization") != null) { - if (!handleAuthRequest(request)) { - // This implies the AuthRequest was not authenticated, hence we error out from here. - log.info("Could not authenticate Request: " + request); - return null; - } + return routingManager.provideAdhocBackend(user); + } + + private String getPreviousBackend(HttpServletRequest request) + { + String queryId = extractQueryIdIfPresent(request); + + // Find query id and get url from cache + if (!isNullOrEmpty(queryId)) { + return routingManager.findBackendForQueryId(queryId); } - String targetLocation = - backendAddress - + request.getRequestURI() - + (request.getQueryString() != null ? "?" + request.getQueryString() : ""); - - String originalLocation = - request.getScheme() - + "://" - + request.getRemoteHost() - + ":" - + request.getServerPort() - + request.getRequestURI() - + (request.getQueryString() != null ? "?" + request.getQueryString() : ""); - - log.info("Rerouting [{}]--> [{}]", originalLocation, targetLocation); - return targetLocation; + log.error("No backend found for queryId {}", queryId); + return getBackendFromRoutingGroup(request); + } + + private boolean isUseStickyRouting(HttpServletRequest request) + { + return !isNullOrEmpty(extractQueryIdIfPresent(request)); } @Override @@ -345,7 +353,7 @@ protected void postConnectionHook( HashMap results = OBJECT_MAPPER.readValue(output, HashMap.class); queryDetail.setQueryId(results.get("id")); - if (!Strings.isNullOrEmpty(queryDetail.getQueryId())) { + if (!isNullOrEmpty(queryDetail.getQueryId())) { routingManager.setBackendForQueryId( queryDetail.getQueryId(), queryDetail.getBackendUrl()); log.debug(