Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion presto-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@
<artifactId>http-server</artifactId>
</dependency>

<dependency>
<groupId>com.facebook.airlift</groupId>
<artifactId>security</artifactId>
</dependency>

<dependency>
<groupId>com.facebook.airlift</groupId>
<artifactId>jaxrs</artifactId>
Expand Down Expand Up @@ -364,7 +369,6 @@
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-jackson</artifactId>
<scope>runtime</scope>
</dependency>

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,13 @@
import com.facebook.presto.server.SessionContext;
import com.facebook.presto.server.SessionPropertyDefaults;
import com.facebook.presto.server.SessionSupplier;
import com.facebook.presto.server.security.SecurityConfig;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.analyzer.AnalyzerOptions;
import com.facebook.presto.spi.analyzer.QueryPreparerProvider;
import com.facebook.presto.spi.resourceGroups.SelectionContext;
import com.facebook.presto.spi.resourceGroups.SelectionCriteria;
import com.facebook.presto.spi.security.AccessControl;
import com.facebook.presto.spi.security.AuthorizedIdentity;
import com.facebook.presto.sql.analyzer.QueryPreparerProviderManager;
import com.facebook.presto.transaction.TransactionManager;
import com.google.common.util.concurrent.AbstractFuture;
Expand All @@ -57,8 +55,6 @@
import java.util.concurrent.Executor;

import static com.facebook.presto.SystemSessionProperties.getAnalyzerType;
import static com.facebook.presto.security.AccessControlUtils.checkPermissions;
import static com.facebook.presto.security.AccessControlUtils.getAuthorizedIdentity;
import static com.facebook.presto.spi.StandardErrorCode.QUERY_TEXT_TOO_LARGE;
import static com.facebook.presto.util.AnalyzerUtil.createAnalyzerOptions;
import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -93,7 +89,6 @@ public class DispatchManager

private final QueryManagerStats stats = new QueryManagerStats();

private final SecurityConfig securityConfig;
private final QueryPreparerProviderManager queryPreparerProviderManager;

/**
Expand Down Expand Up @@ -130,7 +125,6 @@ public DispatchManager(
QueryManagerConfig queryManagerConfig,
DispatchExecutor dispatchExecutor,
ClusterStatusSender clusterStatusSender,
SecurityConfig securityConfig,
Optional<ClusterQueryTrackerService> clusterQueryTrackerService)
{
this.queryIdGenerator = requireNonNull(queryIdGenerator, "queryIdGenerator is null");
Expand All @@ -152,8 +146,6 @@ public DispatchManager(
this.clusterStatusSender = requireNonNull(clusterStatusSender, "clusterStatusSender is null");

this.queryTracker = new QueryTracker<>(queryManagerConfig, dispatchExecutor.getScheduledExecutor(), clusterQueryTrackerService);

this.securityConfig = requireNonNull(securityConfig, "securityConfig is null");
}

/**
Expand Down Expand Up @@ -275,14 +267,8 @@ private <C> void createQueryInternal(QueryId queryId, String slug, int retryCoun
throw new PrestoException(QUERY_TEXT_TOO_LARGE, format("Query text length (%s) exceeds the maximum length (%s)", queryLength, maxQueryLength));
}

// check permissions if needed
checkPermissions(accessControl, securityConfig, queryId, sessionContext);

// get authorized identity if possible
Optional<AuthorizedIdentity> authorizedIdentity = getAuthorizedIdentity(accessControl, securityConfig, queryId, sessionContext);

// decode session
session = sessionSupplier.createSession(queryId, sessionContext, warningCollectorFactory, authorizedIdentity);
session = sessionSupplier.createSession(queryId, sessionContext, warningCollectorFactory);

// prepare query
AnalyzerOptions analyzerOptions = createAnalyzerOptions(session, session.getWarningCollector());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.facebook.presto.metadata.SessionPropertyManager;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.security.AuthorizedIdentity;
import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.spi.security.SelectedRole;
import com.facebook.presto.spi.session.ResourceEstimates;
Expand Down Expand Up @@ -76,6 +77,7 @@
import static com.facebook.presto.client.PrestoHeaders.PRESTO_TRACE_TOKEN;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_TRANSACTION_ID;
import static com.facebook.presto.client.PrestoHeaders.PRESTO_USER;
import static com.facebook.presto.server.security.ServletSecurityUtils.authorizedIdentity;
import static com.facebook.presto.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE;
import static com.google.common.base.Strings.emptyToNull;
import static com.google.common.base.Strings.isNullOrEmpty;
Expand All @@ -99,6 +101,7 @@ public final class HttpRequestSessionContext
private final String schema;

private final Identity identity;
private final Optional<AuthorizedIdentity> authorizedIdentity;
private final List<X509Certificate> certificates;

private final String source;
Expand Down Expand Up @@ -155,6 +158,7 @@ public HttpRequestSessionContext(HttpServletRequest servletRequest, SqlParserOpt
ImmutableMap.of(),
Optional.empty(),
Optional.empty());
authorizedIdentity = authorizedIdentity(servletRequest);

X509Certificate[] certs = (X509Certificate[]) servletRequest.getAttribute(X509_ATTRIBUTE);
if (certs != null && certs.length > 0) {
Expand Down Expand Up @@ -404,6 +408,12 @@ public Identity getIdentity()
return identity;
}

@Override
public Optional<AuthorizedIdentity> getAuthorizedIdentity()
{
return authorizedIdentity;
}

@Override
public List<X509Certificate> getCertificates()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
import com.facebook.presto.Session;
import com.facebook.presto.execution.warnings.WarningCollectorFactory;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.security.AuthorizedIdentity;

import java.util.Optional;

/**
* Used on workers.
Expand All @@ -27,7 +24,7 @@ public class NoOpSessionSupplier
implements SessionSupplier
{
@Override
public Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory, Optional<AuthorizedIdentity> authorizedIdentity)
public Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory)
{
throw new UnsupportedOperationException();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import com.facebook.presto.common.type.TimeZoneKey;
import com.facebook.presto.execution.warnings.WarningCollectorFactory;
import com.facebook.presto.metadata.SessionPropertyManager;
import com.facebook.presto.server.security.SecurityConfig;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.SqlFunctionId;
Expand All @@ -39,6 +40,8 @@
import static com.facebook.presto.Session.SessionBuilder;
import static com.facebook.presto.SystemSessionProperties.WARNING_HANDLING;
import static com.facebook.presto.common.type.TimeZoneKey.getTimeZoneKey;
import static com.facebook.presto.security.AccessControlUtils.checkPermissions;
import static com.facebook.presto.security.AccessControlUtils.getAuthorizedIdentity;
import static java.util.Map.Entry;
import static java.util.Objects.requireNonNull;

Expand All @@ -52,44 +55,30 @@ public class QuerySessionSupplier
private final AccessControl accessControl;
private final SessionPropertyManager sessionPropertyManager;
private final Optional<TimeZoneKey> forcedSessionTimeZone;
private final SecurityConfig securityConfig;

@Inject
public QuerySessionSupplier(
TransactionManager transactionManager,
AccessControl accessControl,
SessionPropertyManager sessionPropertyManager,
SqlEnvironmentConfig config)
SqlEnvironmentConfig sqlEnvironmentConfig,
SecurityConfig securityConfig)
{
this.transactionManager = requireNonNull(transactionManager, "transactionManager is null");
this.accessControl = requireNonNull(accessControl, "accessControl is null");
this.sessionPropertyManager = requireNonNull(sessionPropertyManager, "sessionPropertyManager is null");
requireNonNull(config, "config is null");
this.forcedSessionTimeZone = requireNonNull(config.getForcedSessionTimeZone(), "forcedSessionTimeZone is null");
requireNonNull(sqlEnvironmentConfig, "sqlEnvironmentConfig is null");
this.forcedSessionTimeZone = requireNonNull(sqlEnvironmentConfig.getForcedSessionTimeZone(), "forcedSessionTimeZone is null");
this.securityConfig = requireNonNull(securityConfig, "securityConfig is null");
}

@Override
public Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory, Optional<AuthorizedIdentity> authorizedIdentity)
public Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory)
{
Identity identity = context.getIdentity();
if (authorizedIdentity.isPresent()) {
identity = new Identity(
identity.getUser(),
identity.getPrincipal(),
identity.getRoles(),
identity.getExtraCredentials(),
identity.getExtraAuthenticators(),
Optional.of(authorizedIdentity.get().getUserName()),
authorizedIdentity.get().getReasonForSelect());
log.info(String.format(
"For query %s, given user is %s, authorized user is %s",
queryId.getId(),
identity.getUser(),
authorizedIdentity.get().getUserName()));
}

SessionBuilder sessionBuilder = Session.builder(sessionPropertyManager)
.setQueryId(queryId)
.setIdentity(identity)
.setIdentity(authenticateIdentity(queryId, context))
.setSource(context.getSource())
.setCatalog(context.getCatalog())
.setSchema(context.getSchema())
Expand Down Expand Up @@ -145,4 +134,20 @@ else if (context.getTimeZoneId() != null) {
}
return session;
}

private Identity authenticateIdentity(QueryId queryId, SessionContext context)
{
checkPermissions(accessControl, securityConfig, queryId, context);
Optional<AuthorizedIdentity> authorizedIdentity = context.getAuthorizedIdentity();
authorizedIdentity = authorizedIdentity.isPresent() ? authorizedIdentity : getAuthorizedIdentity(accessControl, securityConfig, queryId, context);

return authorizedIdentity.map(identity -> new Identity(
context.getIdentity().getUser(),
context.getIdentity().getPrincipal(),
context.getIdentity().getRoles(),
context.getIdentity().getExtraCredentials(),
context.getIdentity().getExtraAuthenticators(),
Optional.of(identity.getUserName()),
identity.getReasonForSelect())).orElseGet(context::getIdentity);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.common.transaction.TransactionId;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.security.AuthorizedIdentity;
import com.facebook.presto.spi.security.Identity;
import com.facebook.presto.spi.session.ResourceEstimates;
import com.facebook.presto.spi.tracing.Tracer;
Expand All @@ -34,6 +35,11 @@ public interface SessionContext
{
Identity getIdentity();

default Optional<AuthorizedIdentity> getAuthorizedIdentity()
{
return Optional.empty();
}

default List<X509Certificate> getCertificates()
{
return ImmutableList.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@
import com.facebook.presto.Session;
import com.facebook.presto.execution.warnings.WarningCollectorFactory;
import com.facebook.presto.spi.QueryId;
import com.facebook.presto.spi.security.AuthorizedIdentity;

import java.util.Optional;

public interface SessionSupplier
{
Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory, Optional<AuthorizedIdentity> authorizedIdentity);
Session createSession(QueryId queryId, SessionContext context, WarningCollectorFactory warningCollectorFactory);
}
Loading