Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@
@Retention(RUNTIME)
@Target({FIELD, PARAMETER, METHOD})
@BindingAnnotation
public @interface ForJwk
public @interface ForJwt
{
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.annotation.processing.Generated;
import javax.inject.Inject;

import java.io.IOException;
import java.net.URI;
Expand Down Expand Up @@ -55,8 +54,7 @@ public final class JwkService
@Generated("this")
private Closer closer;

@Inject
public JwkService(@ForJwk URI address, @ForJwk HttpClient httpClient)
public JwkService(URI address, HttpClient httpClient)
{
this(address, httpClient, new Duration(15, TimeUnit.MINUTES));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.security.SecurityException;

import javax.inject.Inject;

import java.security.Key;

import static java.util.Objects.requireNonNull;
Expand All @@ -29,7 +27,6 @@ public class JwkSigningKeyResolver
{
private final JwkService keys;

@Inject
public JwkSigningKeyResolver(JwkService keys)
{
this.keys = requireNonNull(keys, "keys is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class JwtAuthenticator
private final UserMapping userMapping;

@Inject
public JwtAuthenticator(JwtAuthenticatorConfig config, SigningKeyResolver signingKeyResolver)
public JwtAuthenticator(JwtAuthenticatorConfig config, @ForJwt SigningKeyResolver signingKeyResolver)
{
principalField = config.getPrincipalField();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import com.google.inject.Provides;
import com.google.inject.Scopes;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.airlift.http.client.HttpClient;
import io.jsonwebtoken.SigningKeyResolver;

import javax.inject.Singleton;
Expand All @@ -39,7 +40,7 @@ protected void setup(Binder binder)
JwtAuthenticatorConfig.class,
JwtAuthenticatorSupportModule::isHttp,
new JwkModule(),
jwkBinder -> jwkBinder.bind(SigningKeyResolver.class).to(FileSigningKeyResolver.class).in(Scopes.SINGLETON)));
jwkBinder -> jwkBinder.bind(SigningKeyResolver.class).annotatedWith(ForJwt.class).to(FileSigningKeyResolver.class).in(Scopes.SINGLETON)));
}

private static boolean isHttp(JwtAuthenticatorConfig config)
Expand All @@ -53,10 +54,8 @@ private static class JwkModule
@Override
public void configure(Binder binder)
{
binder.bind(SigningKeyResolver.class).to(JwkSigningKeyResolver.class).in(Scopes.SINGLETON);
binder.bind(JwkService.class).in(Scopes.SINGLETON);
httpClientBinder(binder)
.bindHttpClient("jwk", ForJwk.class)
.bindHttpClient("jwk", ForJwt.class)
// Reset HttpClient default configuration to override InternalCommunicationModule changes.
// Setting a keystore and/or a truststore for internal communication changes the default SSL configuration
// for all clients in the same guice context. This, however, does not make sense for this client which will
Expand All @@ -72,10 +71,18 @@ public void configure(Binder binder)

@Provides
@Singleton
@ForJwk
public static URI createJwkAddress(JwtAuthenticatorConfig config)
@ForJwt
public static JwkService createJwkService(JwtAuthenticatorConfig config, @ForJwt HttpClient httpClient)
{
return URI.create(config.getKeyFile());
return new JwkService(URI.create(config.getKeyFile()), httpClient);
}

@Provides
@Singleton
@ForJwt
public static SigningKeyResolver createJwkSigningKeyResolver(@ForJwt JwkService jwkService)
{
return new JwkSigningKeyResolver(jwkService);
}

// this module can be added multiple times, and this prevents multiple processing by Guice
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@
package io.trino.server.security.oauth2;

import com.google.inject.Binder;
import com.google.inject.Key;
import com.google.inject.Provides;
import com.google.inject.Scopes;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.airlift.http.client.HttpClient;
import io.jsonwebtoken.SigningKeyResolver;
import io.trino.server.security.jwt.ForJwk;
import io.trino.server.security.jwt.JwkService;
import io.trino.server.security.jwt.JwkSigningKeyResolver;
import io.trino.server.ui.OAuth2WebUiInstalled;
Expand Down Expand Up @@ -62,18 +60,22 @@ protected void setup(Binder binder)
.setTrustStorePath(null)
.setTrustStorePassword(null)
.setAutomaticHttpsSharedSecret(null));
// Used by JwkService
binder.bind(HttpClient.class).annotatedWith(ForJwk.class).to(Key.get(HttpClient.class, ForOAuth2.class));
binder.bind(JwkService.class).in(Scopes.SINGLETON);
binder.bind(SigningKeyResolver.class).annotatedWith(ForOAuth2.class).to(JwkSigningKeyResolver.class).in(Scopes.SINGLETON);
}

@Provides
@Singleton
@ForJwk
public static URI createJwkAddress(OAuth2Config config)
@ForOAuth2
public static JwkService createJwkService(OAuth2Config config, @ForOAuth2 HttpClient httpClient)
{
return URI.create(config.getJwksUrl());
return new JwkService(URI.create(config.getJwksUrl()), httpClient);
}

@Provides
@Singleton
@ForOAuth2
public static SigningKeyResolver createSigningKeyResolver(@ForOAuth2 JwkService jwkService)
{
return new JwkSigningKeyResolver(jwkService);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,7 @@ private void verifyOAuth2Authenticator(boolean webUiEnabled, Optional<String> pr
.setProperties(ImmutableMap.<String, String>builder()
.putAll(SECURE_PROPERTIES)
.put("web-ui.enabled", String.valueOf(webUiEnabled))
.put("http-server.authentication.type", "oauth2")
.putAll(getOAuth2Properties(tokenServer))
.put("http-server.authentication.oauth2.principal-field", principalField.orElse("sub"))
.buildOrThrow())
Expand Down Expand Up @@ -673,6 +674,7 @@ public void testOAuth2Groups(Optional<Set<String>> groups)
.setProperties(ImmutableMap.<String, String>builder()
.putAll(SECURE_PROPERTIES)
.put("web-ui.enabled", "true")
.put("http-server.authentication.type", "oauth2")
.putAll(getOAuth2Properties(tokenServer))
.put("http-server.authentication.oauth2.groups-field", GROUPS_CLAIM)
.build())
Expand Down Expand Up @@ -744,6 +746,53 @@ public static Object[][] groups()
};
}

@Test
public void testJwtAndOAuth2AuthenticatorsSeparation()
throws Exception
{
TestingHttpServer jwkServer = createTestingJwkServer();
jwkServer.start();
try (TokenServer tokenServer = new TokenServer(Optional.empty());
TestingTrinoServer server = TestingTrinoServer.builder()
.setProperties(
ImmutableMap.<String, String>builder()
.putAll(SECURE_PROPERTIES)
.put("http-server.authentication.type", "jwt,oauth2")
.put("http-server.authentication.jwt.key-file", jwkServer.getBaseUrl().toString())
.putAll(getOAuth2Properties(tokenServer))
.put("web-ui.enabled", "true")
.buildOrThrow())
.setAdditionalModule(oauth2Module(tokenServer))
.build()) {
server.getInstance(Key.get(AccessControlManager.class)).addSystemAccessControl(TestSystemAccessControl.NO_IMPERSONATION);
HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class));

assertAuthenticationDisabled(httpServerInfo.getHttpUri());

OkHttpClient clientWithOAuthToken = client.newBuilder()
.authenticator((route, response) -> response.request().newBuilder()
.header(AUTHORIZATION, "Bearer " + tokenServer.getAccessToken())
.build())
.build();

assertAuthenticationAutomatic(httpServerInfo.getHttpsUri(), clientWithOAuthToken);

String token = Jwts.builder()
.signWith(JWK_PRIVATE_KEY)
.setHeaderParam(JwsHeader.KEY_ID, JWK_KEY_ID)
.setSubject("test-user")
.setExpiration(Date.from(ZonedDateTime.now().plusMinutes(5).toInstant()))
.compact();

OkHttpClient clientWithJwt = client.newBuilder()
.authenticator((route, response) -> response.request().newBuilder()
.header(AUTHORIZATION, "Bearer " + token)
.build())
.build();
assertAuthenticationAutomatic(httpServerInfo.getHttpsUri(), clientWithJwt);
}
}

private static Module oauth2Module(TokenServer tokenServer)
{
return binder -> {
Expand All @@ -757,7 +806,6 @@ private static Module oauth2Module(TokenServer tokenServer)
private static Map<String, String> getOAuth2Properties(TokenServer tokenServer)
{
return ImmutableMap.<String, String>builder()
.put("http-server.authentication.type", "oauth2")
.put("http-server.authentication.oauth2.issuer", tokenServer.getIssuer())
.put("http-server.authentication.oauth2.jwks-url", tokenServer.getJwksUrl())
.put("http-server.authentication.oauth2.state-key", "test-state-key")
Expand Down