diff --git a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java index 48b2b4664..e64ed07d8 100644 --- a/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java +++ b/gateway-ha/src/main/java/io/trino/gateway/ha/handler/ProxyUtils.java @@ -13,6 +13,7 @@ */ package io.trino.gateway.ha.handler; +import com.google.common.collect.ImmutableSet; import com.google.common.io.CharStreams; import io.airlift.log.Logger; import io.trino.gateway.ha.router.TrinoQueryProperties; @@ -24,6 +25,7 @@ import java.net.URI; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -52,6 +54,8 @@ public final class ProxyUtils * capitalization. */ private static final Pattern QUERY_ID_PARAM_PATTERN = Pattern.compile(".*(?:%2F|(?i)query_?id(?-i)=|^)(\\d+_\\d+_\\d+_\\w+).*"); + private static final Set QUERY_STATE_PATH = ImmutableSet.of("queued", "scheduled", "executing"); + private static final String PARTIAL_CANCEL_PATH = "partialCancel"; private ProxyUtils() {} @@ -100,10 +104,10 @@ public static Optional extractQueryIdIfPresent(String path, String query path = path.replace(matchingStatementPath.orElse(V1_QUERY_PATH), ""); String[] tokens = path.split("/"); if (tokens.length >= 2) { - if (tokens[1].equals("queued") - || tokens[1].equals("scheduled") - || tokens[1].equals("executing") - || tokens[1].equals("partialCancel")) { + if (tokens.length >= 3 && QUERY_STATE_PATH.contains(tokens[1])) { + if (tokens.length >= 4 && tokens[2].equals(PARTIAL_CANCEL_PATH)) { + return Optional.of(tokens[3]); + } return Optional.of(tokens[2]); } return Optional.of(tokens[1]); diff --git a/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java b/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java index b5223d2c9..9ecc3c570 100644 --- a/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java +++ b/gateway-ha/src/test/java/io/trino/gateway/ha/handler/TestQueryIdCachingProxyHandler.java @@ -44,8 +44,14 @@ void testExtractQueryIdFromUrl() throws IOException { List statementPaths = ImmutableList.of("/v1/statement", "/custom/api/statement"); + assertThat(extractQueryIdIfPresent("/v1/statement/queued/20200416_160256_03078_6b4yt/ye6c54db413e65c5de0e99612ab1eaabb8611a8aa/1", null, statementPaths)) + .hasValue("20200416_160256_03078_6b4yt"); + assertThat(extractQueryIdIfPresent("/v1/statement/scheduled/20200416_160256_03078_6b4yt/ye6c54db413e65c5de0e99612ab1eaabb8611a8aa/1", null, statementPaths)) + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/v1/statement/executing/20200416_160256_03078_6b4yt/ya7e884929c67cdf86207a80e7a77ab2166fa2e7b/1368", null, statementPaths)) .hasValue("20200416_160256_03078_6b4yt"); + assertThat(extractQueryIdIfPresent("/v1/statement/executing/partialCancel/20200416_160256_03078_6b4yt/0/yce0e0e038758e454d22d7270de30395e19a28eb6/1", null, statementPaths)) + .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/custom/api/statement/executing/20200416_160256_03078_6b4yt/ya7e884929c67cdf86207a80e7a77ab2166fa2e7b/1368", null, statementPaths)) .hasValue("20200416_160256_03078_6b4yt"); assertThat(extractQueryIdIfPresent("/v1/statement/queued/20200416_160256_03078_6b4yt/y0d7620a6941e78d3950798a1085383234258a566/1", null, statementPaths))