Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@
import static com.facebook.presto.client.PrestoHeaders.PRESTO_USER;
import static com.facebook.presto.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE;
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.base.Strings.isNullOrEmpty;
import static com.google.common.base.Strings.nullToEmpty;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.net.HttpHeaders.USER_AGENT;
import static com.google.common.net.HttpHeaders.X_FORWARDED_FOR;
import static java.lang.String.format;

public final class HttpRequestSessionContext
Expand Down Expand Up @@ -114,7 +116,7 @@ public HttpRequestSessionContext(HttpServletRequest servletRequest)
source = servletRequest.getHeader(PRESTO_SOURCE);
traceToken = Optional.ofNullable(trimEmptyToNull(servletRequest.getHeader(PRESTO_TRACE_TOKEN)));
userAgent = servletRequest.getHeader(USER_AGENT);
remoteUserAddress = servletRequest.getRemoteAddr();
remoteUserAddress = !isNullOrEmpty(servletRequest.getHeader(X_FORWARDED_FOR)) ? servletRequest.getHeader(X_FORWARDED_FOR) : servletRequest.getRemoteAddr();
timeZoneId = servletRequest.getHeader(PRESTO_TIME_ZONE);
language = servletRequest.getHeader(PRESTO_LANGUAGE);
clientInfo = servletRequest.getHeader(PRESTO_CLIENT_INFO);
Expand Down Expand Up @@ -161,6 +163,123 @@ else if (nameParts.size() == 2) {
transactionId = parseTransactionId(transactionIdHeader);
}

private static List<String> splitSessionHeader(Enumeration<String> headers)
{
Splitter splitter = Splitter.on(',').trimResults().omitEmptyStrings();
return Collections.list(headers).stream()
.map(splitter::splitToList)
.flatMap(Collection::stream)
.collect(toImmutableList());
}

private static Map<String, String> parseSessionHeaders(HttpServletRequest servletRequest)
{
return parseProperty(servletRequest, PRESTO_SESSION);
}

private static Map<String, SelectedRole> parseRoleHeaders(HttpServletRequest servletRequest)
{
ImmutableMap.Builder<String, SelectedRole> roles = ImmutableMap.builder();
for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_ROLE))) {
List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header);
assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_ROLE);
roles.put(nameValue.get(0), SelectedRole.valueOf(urlDecode(nameValue.get(1))));
}
return roles.build();
}

private static Map<String, String> parseExtraCredentials(HttpServletRequest servletRequest)
{
return parseProperty(servletRequest, PRESTO_EXTRA_CREDENTIAL);
}

private static Map<String, String> parseProperty(HttpServletRequest servletRequest, String headerName)
{
Map<String, String> properties = new HashMap<>();
for (String header : splitSessionHeader(servletRequest.getHeaders(headerName))) {
List<String> nameValue = Splitter.on('=').trimResults().splitToList(header);
assertRequest(nameValue.size() == 2, "Invalid %s header", headerName);
properties.put(nameValue.get(0), nameValue.get(1));
}
return properties;
}

private static void assertRequest(boolean expression, String format, Object... args)
{
if (!expression) {
throw badRequest(format(format, args));
}
}

private static Map<String, String> parsePreparedStatementsHeaders(HttpServletRequest servletRequest)
{
ImmutableMap.Builder<String, String> preparedStatements = ImmutableMap.builder();
for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_PREPARED_STATEMENT))) {
List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header);
assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_PREPARED_STATEMENT);

String statementName;
String sqlString;
try {
statementName = urlDecode(nameValue.get(0));
sqlString = urlDecode(nameValue.get(1));
}
catch (IllegalArgumentException e) {
throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage()));
}

// Validate statement
SqlParser sqlParser = new SqlParser();
try {
sqlParser.createStatement(sqlString, new ParsingOptions(AS_DOUBLE /* anything */));
}
catch (ParsingException e) {
throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage()));
}

preparedStatements.put(statementName, sqlString);
}
return preparedStatements.build();
}

private static Optional<TransactionId> parseTransactionId(String transactionId)
{
transactionId = trimEmptyToNull(transactionId);
if (transactionId == null || transactionId.equalsIgnoreCase("none")) {
return Optional.empty();
}
try {
return Optional.of(TransactionId.valueOf(transactionId));
}
catch (Exception e) {
throw badRequest(e.getMessage());
}
}

private static WebApplicationException badRequest(String message)
{
throw new WebApplicationException(Response
.status(Status.BAD_REQUEST)
.type(MediaType.TEXT_PLAIN)
.entity(message)
.build());
}

private static String trimEmptyToNull(String value)
{
return emptyToNull(nullToEmpty(value).trim());
}

private static String urlDecode(String value)
{
try {
return URLDecoder.decode(value, "UTF-8");
}
catch (UnsupportedEncodingException e) {
throw new AssertionError(e);
}
}

@Override
public Identity getIdentity()
{
Expand Down Expand Up @@ -263,47 +382,6 @@ public Optional<String> getTraceToken()
return traceToken;
}

private static List<String> splitSessionHeader(Enumeration<String> headers)
{
Splitter splitter = Splitter.on(',').trimResults().omitEmptyStrings();
return Collections.list(headers).stream()
.map(splitter::splitToList)
.flatMap(Collection::stream)
.collect(toImmutableList());
}

private static Map<String, String> parseSessionHeaders(HttpServletRequest servletRequest)
{
return parseProperty(servletRequest, PRESTO_SESSION);
}

private static Map<String, SelectedRole> parseRoleHeaders(HttpServletRequest servletRequest)
{
ImmutableMap.Builder<String, SelectedRole> roles = ImmutableMap.builder();
for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_ROLE))) {
List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header);
assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_ROLE);
roles.put(nameValue.get(0), SelectedRole.valueOf(urlDecode(nameValue.get(1))));
}
return roles.build();
}

private static Map<String, String> parseExtraCredentials(HttpServletRequest servletRequest)
{
return parseProperty(servletRequest, PRESTO_EXTRA_CREDENTIAL);
}

private static Map<String, String> parseProperty(HttpServletRequest servletRequest, String headerName)
{
Map<String, String> properties = new HashMap<>();
for (String header : splitSessionHeader(servletRequest.getHeaders(headerName))) {
List<String> nameValue = Splitter.on('=').trimResults().splitToList(header);
assertRequest(nameValue.size() == 2, "Invalid %s header", headerName);
properties.put(nameValue.get(0), nameValue.get(1));
}
return properties;
}

private Set<String> parseClientTags(HttpServletRequest servletRequest)
{
Splitter splitter = Splitter.on(',').trimResults().omitEmptyStrings();
Expand Down Expand Up @@ -341,80 +419,4 @@ private ResourceEstimates parseResourceEstimate(HttpServletRequest servletReques

return builder.build();
}

private static void assertRequest(boolean expression, String format, Object... args)
{
if (!expression) {
throw badRequest(format(format, args));
}
}

private static Map<String, String> parsePreparedStatementsHeaders(HttpServletRequest servletRequest)
{
ImmutableMap.Builder<String, String> preparedStatements = ImmutableMap.builder();
for (String header : splitSessionHeader(servletRequest.getHeaders(PRESTO_PREPARED_STATEMENT))) {
List<String> nameValue = Splitter.on('=').limit(2).trimResults().splitToList(header);
assertRequest(nameValue.size() == 2, "Invalid %s header", PRESTO_PREPARED_STATEMENT);

String statementName;
String sqlString;
try {
statementName = urlDecode(nameValue.get(0));
sqlString = urlDecode(nameValue.get(1));
}
catch (IllegalArgumentException e) {
throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage()));
}

// Validate statement
SqlParser sqlParser = new SqlParser();
try {
sqlParser.createStatement(sqlString, new ParsingOptions(AS_DOUBLE /* anything */));
}
catch (ParsingException e) {
throw badRequest(format("Invalid %s header: %s", PRESTO_PREPARED_STATEMENT, e.getMessage()));
}

preparedStatements.put(statementName, sqlString);
}
return preparedStatements.build();
}

private static Optional<TransactionId> parseTransactionId(String transactionId)
{
transactionId = trimEmptyToNull(transactionId);
if (transactionId == null || transactionId.equalsIgnoreCase("none")) {
return Optional.empty();
}
try {
return Optional.of(TransactionId.valueOf(transactionId));
}
catch (Exception e) {
throw badRequest(e.getMessage());
}
}

private static WebApplicationException badRequest(String message)
{
throw new WebApplicationException(Response
.status(Status.BAD_REQUEST)
.type(MediaType.TEXT_PLAIN)
.entity(message)
.build());
}

private static String trimEmptyToNull(String value)
{
return emptyToNull(nullToEmpty(value).trim());
}

private static String urlDecode(String value)
{
try {
return URLDecoder.decode(value, "UTF-8");
}
catch (UnsupportedEncodingException e) {
throw new AssertionError(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import static com.google.common.net.HttpHeaders.COOKIE;
import static com.google.common.net.HttpHeaders.SET_COOKIE;
import static com.google.common.net.HttpHeaders.USER_AGENT;
import static com.google.common.net.HttpHeaders.X_FORWARDED_FOR;
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
Expand Down Expand Up @@ -189,6 +190,7 @@ private void performRequest(
Request.Builder requestBuilder,
Function<ProxyResponse, Response> responseBuilder)
{
setupXForwardedFor(servletRequest, requestBuilder);
setupBearerToken(servletRequest, requestBuilder);

for (String name : list(servletRequest.getHeaderNames())) {
Expand Down Expand Up @@ -262,6 +264,16 @@ private void setupBearerToken(HttpServletRequest servletRequest, Request.Builder
requestBuilder.addHeader(AUTHORIZATION, "Bearer " + accessToken);
}

private void setupXForwardedFor(HttpServletRequest servletRequest, Request.Builder requestBuilder)
{
StringBuilder xForwardedFor = new StringBuilder();
if (servletRequest.getHeader(X_FORWARDED_FOR) != null) {
xForwardedFor.append(servletRequest.getHeader(X_FORWARDED_FOR) + ",");
}
xForwardedFor.append(servletRequest.getRemoteAddr());
requestBuilder.addHeader(X_FORWARDED_FOR, xForwardedFor.toString());
}

private static <T> T handleProxyException(Request request, ProxyException e)
{
log.warn(e, "Proxy request failed: %s %s", request.getMethod(), request.getUri());
Expand Down