Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.facebook.presto.sql.parser.ParsingException;
import com.facebook.presto.sql.parser.ParsingOptions;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.parser.SqlParserOptions;
import com.facebook.presto.transaction.TransactionId;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -98,7 +99,7 @@ public final class HttpRequestSessionContext
private final boolean clientTransactionSupport;
private final String clientInfo;

public HttpRequestSessionContext(HttpServletRequest servletRequest)
public HttpRequestSessionContext(HttpServletRequest servletRequest, SqlParserOptions sqlParserOptions)
throws WebApplicationException
{
catalog = trimEmptyToNull(servletRequest.getHeader(PRESTO_CATALOG));
Expand Down Expand Up @@ -157,7 +158,7 @@ else if (nameParts.size() == 2) {
this.catalogSessionProperties = catalogSessionProperties.entrySet().stream()
.collect(toImmutableMap(Entry::getKey, entry -> ImmutableMap.copyOf(entry.getValue())));

preparedStatements = parsePreparedStatementsHeaders(servletRequest);
preparedStatements = parsePreparedStatementsHeaders(servletRequest, sqlParserOptions);

String transactionIdHeader = servletRequest.getHeader(PRESTO_TRANSACTION_ID);
clientTransactionSupport = transactionIdHeader != null;
Expand Down Expand Up @@ -223,7 +224,7 @@ private static void assertRequest(boolean expression, String format, Object... a
}
}

private static Map<String, String> parsePreparedStatementsHeaders(HttpServletRequest servletRequest)
private static Map<String, String> parsePreparedStatementsHeaders(HttpServletRequest servletRequest, SqlParserOptions sqlParserOptions)
{
ImmutableMap.Builder<String, String> preparedStatements = ImmutableMap.builder();
for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_PREPARED_STATEMENT))) {
Expand All @@ -241,7 +242,7 @@ private static Map<String, String> parsePreparedStatementsHeaders(HttpServletReq
}

// Validate statement
SqlParser sqlParser = new SqlParser();
SqlParser sqlParser = new SqlParser(sqlParserOptions);
try {
sqlParser.createStatement(sqlString, new ParsingOptions(AS_DOUBLE /* anything */));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.facebook.presto.server.SessionContext;
import com.facebook.presto.spi.ErrorCode;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.sql.parser.SqlParserOptions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Ordering;
Expand Down Expand Up @@ -110,16 +111,19 @@ public class QueuedStatementResource
private final ConcurrentMap<QueryId, Query> queries = new ConcurrentHashMap<>();
private final ScheduledExecutorService queryPurger = newSingleThreadScheduledExecutor(threadsNamed("dispatch-query-purger"));

private final SqlParserOptions sqlParserOptions;

@Inject
public QueuedStatementResource(
DispatchManager dispatchManager,
DispatchExecutor executor,
LocalQueryProvider queryResultsProvider)
LocalQueryProvider queryResultsProvider,
SqlParserOptions sqlParserOptions)
{
this.dispatchManager = requireNonNull(dispatchManager, "dispatchManager is null");
this.queryResultsProvider = queryResultsProvider;
this.sqlParserOptions = requireNonNull(sqlParserOptions, "sqlParserOptions is null");

requireNonNull(dispatchManager, "dispatchManager is null");
this.responseExecutor = requireNonNull(executor, "responseExecutor is null").getExecutor();
this.timeoutExecutor = requireNonNull(executor, "timeoutExecutor is null").getScheduledExecutor();

Expand Down Expand Up @@ -166,7 +170,7 @@ public Response postStatement(
throw badRequest(BAD_REQUEST, "SQL statement is empty");
}

SessionContext sessionContext = new HttpRequestSessionContext(servletRequest);
SessionContext sessionContext = new HttpRequestSessionContext(servletRequest, sqlParserOptions);
Query query = new Query(statement, sessionContext, dispatchManager, queryResultsProvider, timeoutExecutor);
queries.put(query.getQueryId(), query);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.spi.security.SelectedRole;
import com.facebook.presto.sql.parser.IdentifierSymbol;
import com.facebook.presto.sql.parser.SqlParserOptions;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;
Expand All @@ -24,6 +26,7 @@

import java.io.UnsupportedEncodingException;
import java.net.URLEncoder;
import java.util.EnumSet;
import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.HASH_PARTITION_COUNT;
Expand Down Expand Up @@ -67,7 +70,7 @@ public void testSessionContext()
.build(),
"testRemote");

HttpRequestSessionContext context = new HttpRequestSessionContext(request);
HttpRequestSessionContext context = new HttpRequestSessionContext(request, new SqlParserOptions());
assertEquals(context.getSource(), "testSource");
assertEquals(context.getCatalog(), "testCatalog");
assertEquals(context.getSchema(), "testSchema");
Expand Down Expand Up @@ -99,7 +102,28 @@ public void testPreparedStatementsHeaderDoesNotParse()
.put(PRESTO_PREPARED_STATEMENT, "query1=abcdefg")
.build(),
"testRemote");
new HttpRequestSessionContext(request);
new HttpRequestSessionContext(request, new SqlParserOptions());
Comment thread
BlueberryDS marked this conversation as resolved.
Outdated
}

@Test
public void testPreparedStatementsSpecialCharacters()
{
HttpServletRequest request = new MockHttpServletRequest(
ImmutableListMultimap.<String, String>builder()
.put(PRESTO_USER, "testUser")
.put(PRESTO_SOURCE, "testSource")
.put(PRESTO_CATALOG, "testCatalog")
.put(PRESTO_SCHEMA, "testSchema")
.put(PRESTO_LANGUAGE, "zh-TW")
.put(PRESTO_TIME_ZONE, "Asia/Taipei")
.put(PRESTO_CLIENT_INFO, "null")
.put(PRESTO_PREPARED_STATEMENT, "query1=select * from tbl:ns")
.build(),
"testRemote");
SqlParserOptions options = new SqlParserOptions();
options.allowIdentifierSymbol(EnumSet.allOf(IdentifierSymbol.class));

new HttpRequestSessionContext(request, options);
}

@Test
Expand Down Expand Up @@ -127,7 +151,7 @@ public void testExtraCredentials()
.build(),
"testRemote");

HttpRequestSessionContext context = new HttpRequestSessionContext(request);
HttpRequestSessionContext context = new HttpRequestSessionContext(request, new SqlParserOptions());
assertEquals(
context.getIdentity().getExtraCredentials(),
ImmutableMap.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.security.AllowAllAccessControl;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.sql.SqlEnvironmentConfig;
import com.facebook.presto.sql.parser.SqlParserOptions;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
Expand Down Expand Up @@ -66,7 +67,7 @@ public class TestQuerySessionSupplier
@Test
public void testCreateSession()
{
HttpRequestSessionContext context = new HttpRequestSessionContext(TEST_REQUEST);
HttpRequestSessionContext context = new HttpRequestSessionContext(TEST_REQUEST, new SqlParserOptions());
QuerySessionSupplier sessionSupplier = new QuerySessionSupplier(
createTestTransactionManager(),
new AllowAllAccessControl(),
Expand Down Expand Up @@ -103,7 +104,7 @@ public void testEmptyClientTags()
.put(PRESTO_USER, "testUser")
.build(),
"remoteAddress");
HttpRequestSessionContext context1 = new HttpRequestSessionContext(request1);
HttpRequestSessionContext context1 = new HttpRequestSessionContext(request1, new SqlParserOptions());
assertEquals(context1.getClientTags(), ImmutableSet.of());

HttpServletRequest request2 = new MockHttpServletRequest(
Expand All @@ -112,7 +113,7 @@ public void testEmptyClientTags()
.put(PRESTO_CLIENT_TAGS, "")
.build(),
"remoteAddress");
HttpRequestSessionContext context2 = new HttpRequestSessionContext(request2);
HttpRequestSessionContext context2 = new HttpRequestSessionContext(request2, new SqlParserOptions());
assertEquals(context2.getClientTags(), ImmutableSet.of());
}

Expand All @@ -125,7 +126,7 @@ public void testInvalidTimeZone()
.put(PRESTO_TIME_ZONE, "unknown_timezone")
.build(),
"testRemote");
HttpRequestSessionContext context = new HttpRequestSessionContext(request);
HttpRequestSessionContext context = new HttpRequestSessionContext(request, new SqlParserOptions());
QuerySessionSupplier sessionSupplier = new QuerySessionSupplier(
createTestTransactionManager(),
new AllowAllAccessControl(),
Expand Down