diff --git a/presto-main/src/main/java/io/prestosql/server/InternalAuthenticationManager.java b/presto-main/src/main/java/io/prestosql/server/InternalAuthenticationManager.java new file mode 100644 index 000000000000..7b2dd0bfefa5 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/server/InternalAuthenticationManager.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.server; + +import com.google.common.hash.Hashing; +import io.airlift.http.client.HttpRequestFilter; +import io.airlift.http.client.Request; +import io.airlift.log.Logger; +import io.airlift.node.NodeInfo; +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.Jws; +import io.jsonwebtoken.JwtException; +import io.jsonwebtoken.JwtParser; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.SignatureAlgorithm; +import io.prestosql.server.security.InternalPrincipal; + +import javax.inject.Inject; +import javax.servlet.http.HttpServletRequest; + +import java.security.Principal; +import java.time.ZonedDateTime; +import java.util.Date; +import java.util.Optional; +import java.util.function.Supplier; + +import static io.airlift.http.client.Request.Builder.fromRequest; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; + +public class InternalAuthenticationManager + implements HttpRequestFilter +{ + private static final Logger log = Logger.get(InternalAuthenticationManager.class); + + private static final String PRESTO_INTERNAL_BEARER = "X-Presto-Internal-Bearer"; + + private final Optional jwtParser; + private final Optional> jwtGenerator; + + @Inject + public InternalAuthenticationManager(InternalCommunicationConfig internalCommunicationConfig, NodeInfo nodeInfo) + { + this(requireNonNull(internalCommunicationConfig, "internalCommunicationConfig is null").getSharedSecret(), nodeInfo.getNodeId()); + } + + public InternalAuthenticationManager(Optional sharedSecret, String nodeId) + { + if (sharedSecret.isPresent()) { + byte[] hmac = Hashing.sha256().hashString(sharedSecret.get(), UTF_8).asBytes(); + this.jwtParser = Optional.of(Jwts.parser().setSigningKey(hmac)); + this.jwtGenerator = Optional.of(() -> generateJwt(hmac, nodeId)); + } + else { + this.jwtParser = Optional.empty(); + this.jwtGenerator = Optional.empty(); + } + } + + private static String generateJwt(byte[] hmac, String nodeId) + { + return Jwts.builder() + .signWith(SignatureAlgorithm.HS256, hmac) + .setSubject(nodeId) + .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(5).toInstant())) + .compact(); + } + + public boolean isInternalRequest(HttpServletRequest request) + { + return request.getHeader(PRESTO_INTERNAL_BEARER) != null; + } + + public Principal authenticateInternalRequest(HttpServletRequest request) + { + if (!jwtParser.isPresent()) { + log.error("Internal authentication in not configured"); + return null; + } + + String internalBarer = request.getHeader(PRESTO_INTERNAL_BEARER); + try { + Jws claimsJws = jwtParser.get().parseClaimsJws(internalBarer); + String subject = claimsJws.getBody().getSubject(); + return new InternalPrincipal(subject); + } + catch (JwtException e) { + log.error(e, "Internal authentication failed"); + return null; + } + catch (RuntimeException e) { + throw new RuntimeException("Authentication error", e); + } + } + + @Override + public Request filterRequest(Request request) + { + return jwtGenerator.map(Supplier::get) + .map(jwt -> fromRequest(request) + .addHeader(PRESTO_INTERNAL_BEARER, jwt) + .build()) + .orElse(request); + } +} diff --git a/presto-main/src/main/java/io/prestosql/server/InternalCommunicationConfig.java b/presto-main/src/main/java/io/prestosql/server/InternalCommunicationConfig.java index ed1a548ffb53..5fb938600c21 100644 --- a/presto-main/src/main/java/io/prestosql/server/InternalCommunicationConfig.java +++ b/presto-main/src/main/java/io/prestosql/server/InternalCommunicationConfig.java @@ -16,10 +16,15 @@ import io.airlift.configuration.Config; import io.airlift.configuration.ConfigSecuritySensitive; +import javax.validation.constraints.NotNull; + +import java.util.Optional; + public class InternalCommunicationConfig { public static final String INTERNAL_COMMUNICATION_KERBEROS_ENABLED = "internal-communication.kerberos.enabled"; + private String sharedSecret; private boolean httpsRequired; private String keyStorePath; private String keyStorePassword; @@ -28,6 +33,20 @@ public class InternalCommunicationConfig private boolean kerberosEnabled; private boolean kerberosUseCanonicalHostname = true; + @NotNull + public Optional getSharedSecret() + { + return Optional.ofNullable(sharedSecret); + } + + @ConfigSecuritySensitive + @Config("internal-communication.shared-secret") + public InternalCommunicationConfig setSharedSecret(String sharedSecret) + { + this.sharedSecret = sharedSecret; + return this; + } + public boolean isHttpsRequired() { return httpsRequired; diff --git a/presto-main/src/main/java/io/prestosql/server/InternalCommunicationModule.java b/presto-main/src/main/java/io/prestosql/server/InternalCommunicationModule.java index a846371efab9..c06f83866d98 100644 --- a/presto-main/src/main/java/io/prestosql/server/InternalCommunicationModule.java +++ b/presto-main/src/main/java/io/prestosql/server/InternalCommunicationModule.java @@ -27,6 +27,7 @@ import static com.google.common.base.Verify.verify; import static io.airlift.configuration.ConditionalModule.installModuleIf; import static io.airlift.configuration.ConfigBinder.configBinder; +import static io.airlift.http.client.HttpClientBinder.httpClientBinder; import static io.prestosql.server.InternalCommunicationConfig.INTERNAL_COMMUNICATION_KERBEROS_ENABLED; import static io.prestosql.server.security.KerberosConfig.HTTP_SERVER_AUTHENTICATION_KRB5_KEYTAB; @@ -45,6 +46,8 @@ protected void setup(Binder binder) }); install(installModuleIf(InternalCommunicationConfig.class, InternalCommunicationConfig::isKerberosEnabled, kerberosInternalCommunicationModule())); + binder.bind(InternalAuthenticationManager.class); + httpClientBinder(binder).bindGlobalFilter(InternalAuthenticationManager.class); } private Module kerberosInternalCommunicationModule() diff --git a/presto-main/src/main/java/io/prestosql/server/security/AuthenticationFilter.java b/presto-main/src/main/java/io/prestosql/server/security/AuthenticationFilter.java index ed776c1536d5..4ef2cbd815dd 100644 --- a/presto-main/src/main/java/io/prestosql/server/security/AuthenticationFilter.java +++ b/presto-main/src/main/java/io/prestosql/server/security/AuthenticationFilter.java @@ -17,6 +17,7 @@ import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.net.HttpHeaders; +import io.prestosql.server.InternalAuthenticationManager; import javax.inject.Inject; import javax.servlet.Filter; @@ -49,12 +50,14 @@ public class AuthenticationFilter private final List authenticators; private final boolean httpsForwardingEnabled; + private final InternalAuthenticationManager internalAuthenticationManager; @Inject - public AuthenticationFilter(List authenticators, SecurityConfig securityConfig) + public AuthenticationFilter(List authenticators, SecurityConfig securityConfig, InternalAuthenticationManager internalAuthenticationManager) { this.authenticators = ImmutableList.copyOf(requireNonNull(authenticators, "authenticators is null")); this.httpsForwardingEnabled = requireNonNull(securityConfig, "securityConfig is null").getEnableForwardingHttps(); + this.internalAuthenticationManager = requireNonNull(internalAuthenticationManager, "internalAuthenticationManager is null"); } @Override @@ -70,6 +73,16 @@ public void doFilter(ServletRequest servletRequest, ServletResponse servletRespo HttpServletRequest request = (HttpServletRequest) servletRequest; HttpServletResponse response = (HttpServletResponse) servletResponse; + if (internalAuthenticationManager.isInternalRequest(request)) { + Principal principal = internalAuthenticationManager.authenticateInternalRequest(request); + if (principal == null) { + response.sendError(SC_UNAUTHORIZED); + return; + } + nextFilter.doFilter(withPrincipal(request, principal), response); + return; + } + // skip authentication if non-secure or not configured if (!doesRequestSupportAuthentication(request)) { nextFilter.doFilter(request, response); diff --git a/presto-main/src/main/java/io/prestosql/server/security/InternalPrincipal.java b/presto-main/src/main/java/io/prestosql/server/security/InternalPrincipal.java new file mode 100644 index 000000000000..cdddc497d511 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/server/security/InternalPrincipal.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.server.security; + +import java.security.Principal; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public final class InternalPrincipal + implements Principal +{ + private final String name; + + public InternalPrincipal(String name) + { + this.name = requireNonNull(name, "name is null"); + } + + @Override + public String getName() + { + return name; + } + + @Override + public String toString() + { + return name; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + InternalPrincipal that = (InternalPrincipal) o; + return Objects.equals(name, that.name); + } + + @Override + public int hashCode() + { + return Objects.hash(name); + } +} diff --git a/presto-main/src/test/java/io/prestosql/server/TestGenerateTokenFilter.java b/presto-main/src/test/java/io/prestosql/server/TestGenerateTokenFilter.java index 2c8b97236f23..bed58dd0801d 100644 --- a/presto-main/src/test/java/io/prestosql/server/TestGenerateTokenFilter.java +++ b/presto-main/src/test/java/io/prestosql/server/TestGenerateTokenFilter.java @@ -63,9 +63,9 @@ public void setup() // extract the filter List filters = httpClient.getRequestFilters(); - assertEquals(filters.size(), 2); - assertInstanceOf(filters.get(1), GenerateTraceTokenRequestFilter.class); - filter = (GenerateTraceTokenRequestFilter) filters.get(1); + assertEquals(filters.size(), 3); + assertInstanceOf(filters.get(2), GenerateTraceTokenRequestFilter.class); + filter = (GenerateTraceTokenRequestFilter) filters.get(2); } @AfterClass(alwaysRun = true) diff --git a/presto-main/src/test/java/io/prestosql/server/TestInternalCommunicationConfig.java b/presto-main/src/test/java/io/prestosql/server/TestInternalCommunicationConfig.java new file mode 100644 index 000000000000..759217002d0b --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/server/TestInternalCommunicationConfig.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.server; + +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static io.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static io.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static io.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestInternalCommunicationConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(InternalCommunicationConfig.class) + .setSharedSecret(null) + .setHttpsRequired(false) + .setKeyStorePath(null) + .setKeyStorePassword(null) + .setTrustStorePath(null) + .setTrustStorePassword(null) + .setKerberosEnabled(false) + .setKerberosUseCanonicalHostname(true)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("internal-communication.shared-secret", "secret") + .put("internal-communication.https.required", "true") + .put("internal-communication.https.keystore.path", "key-path") + .put("internal-communication.https.keystore.key", "key-key") + .put("internal-communication.https.truststore.path", "trust-path") + .put("internal-communication.https.truststore.key", "trust-key") + .put("internal-communication.kerberos.enabled", "true") + .put("internal-communication.kerberos.use-canonical-hostname", "false") + .build(); + + InternalCommunicationConfig expected = new InternalCommunicationConfig() + .setSharedSecret("secret") + .setHttpsRequired(true) + .setKeyStorePath("key-path") + .setKeyStorePassword("key-key") + .setTrustStorePath("trust-path") + .setTrustStorePassword("trust-key") + .setKerberosEnabled(true) + .setKerberosUseCanonicalHostname(false); + + assertFullMapping(properties, expected); + } +} diff --git a/presto-testing/src/main/java/io/prestosql/testing/DistributedQueryRunner.java b/presto-testing/src/main/java/io/prestosql/testing/DistributedQueryRunner.java index 399f3e8a9900..d141b4f22cb4 100644 --- a/presto-testing/src/main/java/io/prestosql/testing/DistributedQueryRunner.java +++ b/presto-testing/src/main/java/io/prestosql/testing/DistributedQueryRunner.java @@ -180,6 +180,7 @@ private static TestingPrestoServer createTestingPrestoServer(URI discoveryUri, b { long start = System.nanoTime(); ImmutableMap.Builder propertiesBuilder = ImmutableMap.builder() + .put("internal-communication.shared-secret", "test-secret") .put("query.client.timeout", "10m") .put("exchange.http-client.idle-timeout", "1h") .put("task.max-index-memory", "16kB") // causes index joins to fault load