diff --git a/core/trino-main/src/main/java/io/trino/server/security/jwt/ForJwk.java b/core/trino-main/src/main/java/io/trino/server/security/jwt/ForJwt.java similarity index 97% rename from core/trino-main/src/main/java/io/trino/server/security/jwt/ForJwk.java rename to core/trino-main/src/main/java/io/trino/server/security/jwt/ForJwt.java index bc4cf9396473..8c34eda6d4b4 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/jwt/ForJwk.java +++ b/core/trino-main/src/main/java/io/trino/server/security/jwt/ForJwt.java @@ -26,6 +26,6 @@ @Retention(RUNTIME) @Target({FIELD, PARAMETER, METHOD}) @BindingAnnotation -public @interface ForJwk +public @interface ForJwt { } diff --git a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwkService.java b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwkService.java index db72ef0220f8..61880221bbb9 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwkService.java +++ b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwkService.java @@ -25,7 +25,6 @@ import javax.annotation.PostConstruct; import javax.annotation.PreDestroy; import javax.annotation.processing.Generated; -import javax.inject.Inject; import java.io.IOException; import java.net.URI; @@ -55,8 +54,7 @@ public final class JwkService @Generated("this") private Closer closer; - @Inject - public JwkService(@ForJwk URI address, @ForJwk HttpClient httpClient) + public JwkService(URI address, HttpClient httpClient) { this(address, httpClient, new Duration(15, TimeUnit.MINUTES)); } diff --git a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwkSigningKeyResolver.java b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwkSigningKeyResolver.java index 7b362bce035d..51ca571913c8 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwkSigningKeyResolver.java +++ b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwkSigningKeyResolver.java @@ -18,8 +18,6 @@ import io.jsonwebtoken.SigningKeyResolver; import io.jsonwebtoken.security.SecurityException; -import javax.inject.Inject; - import java.security.Key; import static java.util.Objects.requireNonNull; @@ -29,7 +27,6 @@ public class JwkSigningKeyResolver { private final JwkService keys; - @Inject public JwkSigningKeyResolver(JwkService keys) { this.keys = requireNonNull(keys, "keys is null"); diff --git a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java index 5989470f0384..c7de805dfede 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java +++ b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java @@ -39,7 +39,7 @@ public class JwtAuthenticator private final UserMapping userMapping; @Inject - public JwtAuthenticator(JwtAuthenticatorConfig config, SigningKeyResolver signingKeyResolver) + public JwtAuthenticator(JwtAuthenticatorConfig config, @ForJwt SigningKeyResolver signingKeyResolver) { principalField = config.getPrincipalField(); diff --git a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java index 71e2067d83f5..bf4801d74aba 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java +++ b/core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticatorSupportModule.java @@ -18,6 +18,7 @@ import com.google.inject.Provides; import com.google.inject.Scopes; import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.airlift.http.client.HttpClient; import io.jsonwebtoken.SigningKeyResolver; import javax.inject.Singleton; @@ -39,7 +40,7 @@ protected void setup(Binder binder) JwtAuthenticatorConfig.class, JwtAuthenticatorSupportModule::isHttp, new JwkModule(), - jwkBinder -> jwkBinder.bind(SigningKeyResolver.class).to(FileSigningKeyResolver.class).in(Scopes.SINGLETON))); + jwkBinder -> jwkBinder.bind(SigningKeyResolver.class).annotatedWith(ForJwt.class).to(FileSigningKeyResolver.class).in(Scopes.SINGLETON))); } private static boolean isHttp(JwtAuthenticatorConfig config) @@ -53,10 +54,8 @@ private static class JwkModule @Override public void configure(Binder binder) { - binder.bind(SigningKeyResolver.class).to(JwkSigningKeyResolver.class).in(Scopes.SINGLETON); - binder.bind(JwkService.class).in(Scopes.SINGLETON); httpClientBinder(binder) - .bindHttpClient("jwk", ForJwk.class) + .bindHttpClient("jwk", ForJwt.class) // Reset HttpClient default configuration to override InternalCommunicationModule changes. // Setting a keystore and/or a truststore for internal communication changes the default SSL configuration // for all clients in the same guice context. This, however, does not make sense for this client which will @@ -72,10 +71,18 @@ public void configure(Binder binder) @Provides @Singleton - @ForJwk - public static URI createJwkAddress(JwtAuthenticatorConfig config) + @ForJwt + public static JwkService createJwkService(JwtAuthenticatorConfig config, @ForJwt HttpClient httpClient) { - return URI.create(config.getKeyFile()); + return new JwkService(URI.create(config.getKeyFile()), httpClient); + } + + @Provides + @Singleton + @ForJwt + public static SigningKeyResolver createJwkSigningKeyResolver(@ForJwt JwkService jwkService) + { + return new JwkSigningKeyResolver(jwkService); } // this module can be added multiple times, and this prevents multiple processing by Guice diff --git a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java index 9d0d41a13e10..7a3b15f4a8bd 100644 --- a/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java +++ b/core/trino-main/src/main/java/io/trino/server/security/oauth2/OAuth2ServiceModule.java @@ -14,13 +14,11 @@ package io.trino.server.security.oauth2; import com.google.inject.Binder; -import com.google.inject.Key; import com.google.inject.Provides; import com.google.inject.Scopes; import io.airlift.configuration.AbstractConfigurationAwareModule; import io.airlift.http.client.HttpClient; import io.jsonwebtoken.SigningKeyResolver; -import io.trino.server.security.jwt.ForJwk; import io.trino.server.security.jwt.JwkService; import io.trino.server.security.jwt.JwkSigningKeyResolver; import io.trino.server.ui.OAuth2WebUiInstalled; @@ -62,18 +60,22 @@ protected void setup(Binder binder) .setTrustStorePath(null) .setTrustStorePassword(null) .setAutomaticHttpsSharedSecret(null)); - // Used by JwkService - binder.bind(HttpClient.class).annotatedWith(ForJwk.class).to(Key.get(HttpClient.class, ForOAuth2.class)); - binder.bind(JwkService.class).in(Scopes.SINGLETON); - binder.bind(SigningKeyResolver.class).annotatedWith(ForOAuth2.class).to(JwkSigningKeyResolver.class).in(Scopes.SINGLETON); } @Provides @Singleton - @ForJwk - public static URI createJwkAddress(OAuth2Config config) + @ForOAuth2 + public static JwkService createJwkService(OAuth2Config config, @ForOAuth2 HttpClient httpClient) { - return URI.create(config.getJwksUrl()); + return new JwkService(URI.create(config.getJwksUrl()), httpClient); + } + + @Provides + @Singleton + @ForOAuth2 + public static SigningKeyResolver createSigningKeyResolver(@ForOAuth2 JwkService jwkService) + { + return new JwkSigningKeyResolver(jwkService); } @Override diff --git a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java index 8923f9f6887a..c9d477a3b036 100644 --- a/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java +++ b/core/trino-main/src/test/java/io/trino/server/security/TestResourceSecurity.java @@ -548,6 +548,7 @@ private void verifyOAuth2Authenticator(boolean webUiEnabled, Optional pr .setProperties(ImmutableMap.builder() .putAll(SECURE_PROPERTIES) .put("web-ui.enabled", String.valueOf(webUiEnabled)) + .put("http-server.authentication.type", "oauth2") .putAll(getOAuth2Properties(tokenServer)) .put("http-server.authentication.oauth2.principal-field", principalField.orElse("sub")) .buildOrThrow()) @@ -673,6 +674,7 @@ public void testOAuth2Groups(Optional> groups) .setProperties(ImmutableMap.builder() .putAll(SECURE_PROPERTIES) .put("web-ui.enabled", "true") + .put("http-server.authentication.type", "oauth2") .putAll(getOAuth2Properties(tokenServer)) .put("http-server.authentication.oauth2.groups-field", GROUPS_CLAIM) .build()) @@ -744,6 +746,53 @@ public static Object[][] groups() }; } + @Test + public void testJwtAndOAuth2AuthenticatorsSeparation() + throws Exception + { + TestingHttpServer jwkServer = createTestingJwkServer(); + jwkServer.start(); + try (TokenServer tokenServer = new TokenServer(Optional.empty()); + TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties( + ImmutableMap.builder() + .putAll(SECURE_PROPERTIES) + .put("http-server.authentication.type", "jwt,oauth2") + .put("http-server.authentication.jwt.key-file", jwkServer.getBaseUrl().toString()) + .putAll(getOAuth2Properties(tokenServer)) + .put("web-ui.enabled", "true") + .buildOrThrow()) + .setAdditionalModule(oauth2Module(tokenServer)) + .build()) { + server.getInstance(Key.get(AccessControlManager.class)).addSystemAccessControl(TestSystemAccessControl.NO_IMPERSONATION); + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); + + assertAuthenticationDisabled(httpServerInfo.getHttpUri()); + + OkHttpClient clientWithOAuthToken = client.newBuilder() + .authenticator((route, response) -> response.request().newBuilder() + .header(AUTHORIZATION, "Bearer " + tokenServer.getAccessToken()) + .build()) + .build(); + + assertAuthenticationAutomatic(httpServerInfo.getHttpsUri(), clientWithOAuthToken); + + String token = Jwts.builder() + .signWith(JWK_PRIVATE_KEY) + .setHeaderParam(JwsHeader.KEY_ID, JWK_KEY_ID) + .setSubject("test-user") + .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(5).toInstant())) + .compact(); + + OkHttpClient clientWithJwt = client.newBuilder() + .authenticator((route, response) -> response.request().newBuilder() + .header(AUTHORIZATION, "Bearer " + token) + .build()) + .build(); + assertAuthenticationAutomatic(httpServerInfo.getHttpsUri(), clientWithJwt); + } + } + private static Module oauth2Module(TokenServer tokenServer) { return binder -> { @@ -757,7 +806,6 @@ private static Module oauth2Module(TokenServer tokenServer) private static Map getOAuth2Properties(TokenServer tokenServer) { return ImmutableMap.builder() - .put("http-server.authentication.type", "oauth2") .put("http-server.authentication.oauth2.issuer", tokenServer.getIssuer()) .put("http-server.authentication.oauth2.jwks-url", tokenServer.getJwksUrl()) .put("http-server.authentication.oauth2.state-key", "test-state-key")