Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.codahale.metrics.Meter;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.io.CharStreams;
import io.trino.gateway.ha.router.QueryHistoryManager;
import io.trino.gateway.ha.router.RoutingGroupSelector;
Expand All @@ -41,6 +40,8 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static com.google.common.base.Strings.isNullOrEmpty;

public class QueryIdCachingProxyHandler
extends ProxyHandler
{
Expand Down Expand Up @@ -71,8 +72,8 @@ public class QueryIdCachingProxyHandler
private final QueryHistoryManager queryHistoryManager;

private final Meter requestMeter;
private final int serverApplicationPort;
private final List<String> extraWhitelistPaths;
private final String applicationEndpoint;

public QueryIdCachingProxyHandler(
QueryHistoryManager queryHistoryManager,
Expand All @@ -85,8 +86,8 @@ public QueryIdCachingProxyHandler(
this.routingManager = routingManager;
this.routingGroupSelector = routingGroupSelector;
this.queryHistoryManager = queryHistoryManager;
this.serverApplicationPort = serverApplicationPort;
this.extraWhitelistPaths = extraWhitelistPaths;
this.applicationEndpoint = "http://localhost:" + serverApplicationPort;
}

protected static String extractQueryIdIfPresent(String path, String queryParams)
Expand Down Expand Up @@ -151,7 +152,7 @@ static String getQueryUser(HttpServletRequest request)
{
String trinoUser = request.getHeader(USER_HEADER);

if (!Strings.isNullOrEmpty(trinoUser)) {
if (!isNullOrEmpty(trinoUser)) {
log.info("user from {}", USER_HEADER);
return trinoUser;
}
Expand All @@ -171,7 +172,7 @@ static String getQueryUser(HttpServletRequest request)
}

String headerInfo = header.substring(space + 1).trim();
if (Strings.isNullOrEmpty(headerInfo)) {
if (isNullOrEmpty(headerInfo)) {
log.error("The encoded value of basic auth doesn't exist");
return user;
}
Expand All @@ -191,7 +192,7 @@ protected String extractQueryIdIfPresent(HttpServletRequest request)
String queryParams = request.getQueryString();
try {
String queryText = CharStreams.toString(request.getReader());
if (!Strings.isNullOrEmpty(queryText)
if (!isNullOrEmpty(queryText)
&& queryText.toLowerCase().contains("system.runtime.kill_query")) {
// extract and return the queryId
String[] parts = queryText.split(",");
Expand All @@ -200,7 +201,7 @@ protected String extractQueryIdIfPresent(HttpServletRequest request)
Matcher m = EXTRACT_BETWEEN_SINGLE_QUOTES.matcher(part);
if (m.find()) {
String queryQuoted = m.group();
if (!Strings.isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
if (!isNullOrEmpty(queryQuoted) && queryQuoted.length() > 0) {
return queryQuoted.substring(1, queryQuoted.length() - 1);
}
}
Expand Down Expand Up @@ -251,11 +252,6 @@ private boolean isPathWhiteListed(String path)
|| extraWhitelistPaths.stream().anyMatch(s -> path.startsWith(s));
}

public boolean isAuthEnabled()
{
return false;
}

public boolean handleAuthRequest(HttpServletRequest request)
{
return true;
Expand All @@ -264,54 +260,66 @@ public boolean handleAuthRequest(HttpServletRequest request)
@Override
public String rewriteTarget(HttpServletRequest request)
{
/* Here comes the load balancer / gateway */
String backendAddress = "http://localhost:" + serverApplicationPort;
if (!isPathWhiteListed(request.getRequestURI())) {
return buildUriWithNewBackend(applicationEndpoint, request);
}

// Only load balance trino query APIs.
if (isPathWhiteListed(request.getRequestURI())) {
String queryId = extractQueryIdIfPresent(request);
if (isUseStickyRouting(request)) {
String backend = getPreviousBackend(request);
logRewrite(backend, request);
return buildUriWithNewBackend(backend, request);
}

// Find query id and get url from cache
if (!Strings.isNullOrEmpty(queryId)) {
backendAddress = routingManager.findBackendForQueryId(queryId);
}
else {
String routingGroup = routingGroupSelector.findRoutingGroup(request);
String user = request.getHeader(USER_HEADER);
if (!Strings.isNullOrEmpty(routingGroup)) {
// This falls back on adhoc backend if there are no cluster found for the routing group.
backendAddress = routingManager.provideBackendForRoutingGroup(routingGroup, user);
}
else {
backendAddress = routingManager.provideAdhocBackend(user);
}
}
// set target backend so that we could save queryId to backend mapping later.
((MultiReadHttpServletRequest) request).addHeader(PROXY_TARGET_HEADER, backendAddress);
String backend = getBackendFromRoutingGroup(request);
// set target backend so that we could save queryId to backend mapping later.
((MultiReadHttpServletRequest) request).addHeader(PROXY_TARGET_HEADER, backend);
logRewrite(backend, request);

return buildUriWithNewBackend(backend, request);
}

private void logRewrite(String newBackend, HttpServletRequest request)
{
log.info("Rerouting [{}://{}:{}{}{}]--> [{}]",
request.getScheme(),
request.getRemoteHost(),
request.getServerPort(),
request.getRequestURI(),
(request.getQueryString() != null ? "?" + request.getQueryString() : ""),
buildUriWithNewBackend(newBackend, request));
}

private String buildUriWithNewBackend(String backendHost, HttpServletRequest request)
{
return backendHost + request.getRequestURI() + (request.getQueryString() != null ? "?" + request.getQueryString() : "");
}

private String getBackendFromRoutingGroup(HttpServletRequest request)
{
String routingGroup = routingGroupSelector.findRoutingGroup(request);
String user = request.getHeader(USER_HEADER);
if (!isNullOrEmpty(routingGroup)) {
// This falls back on adhoc backend if there are no cluster found for the routing group.
return routingManager.provideBackendForRoutingGroup(routingGroup, user);
}
if (isAuthEnabled() && request.getHeader("Authorization") != null) {
if (!handleAuthRequest(request)) {
// This implies the AuthRequest was not authenticated, hence we error out from here.
log.info("Could not authenticate Request: " + request);
return null;
}
return routingManager.provideAdhocBackend(user);
}

private String getPreviousBackend(HttpServletRequest request)
{
String queryId = extractQueryIdIfPresent(request);

// Find query id and get url from cache
if (!isNullOrEmpty(queryId)) {
return routingManager.findBackendForQueryId(queryId);
}
String targetLocation =
backendAddress
+ request.getRequestURI()
+ (request.getQueryString() != null ? "?" + request.getQueryString() : "");

String originalLocation =
request.getScheme()
+ "://"
+ request.getRemoteHost()
+ ":"
+ request.getServerPort()
+ request.getRequestURI()
+ (request.getQueryString() != null ? "?" + request.getQueryString() : "");

log.info("Rerouting [{}]--> [{}]", originalLocation, targetLocation);
return targetLocation;
log.error("No backend found for queryId {}", queryId);
return getBackendFromRoutingGroup(request);
}

private boolean isUseStickyRouting(HttpServletRequest request)
{
return !isNullOrEmpty(extractQueryIdIfPresent(request));
}

@Override
Expand Down Expand Up @@ -345,7 +353,7 @@ protected void postConnectionHook(
HashMap<String, String> results = OBJECT_MAPPER.readValue(output, HashMap.class);
queryDetail.setQueryId(results.get("id"));

if (!Strings.isNullOrEmpty(queryDetail.getQueryId())) {
if (!isNullOrEmpty(queryDetail.getQueryId())) {
routingManager.setBackendForQueryId(
queryDetail.getQueryId(), queryDetail.getBackendUrl());
log.debug(
Expand Down