Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
import jakarta.ws.rs.container.ContainerRequestContext;
import jakarta.ws.rs.core.Response;

import java.security.Key;
import javax.crypto.SecretKey;

import java.time.ZonedDateTime;
import java.util.Date;

Expand All @@ -49,7 +50,7 @@ public class InternalAuthenticationManager

private static final String TRINO_INTERNAL_BEARER = "X-Trino-Internal-Bearer";

private final Key hmac;
private final SecretKey hmac;
private final String nodeId;
private final JwtParser jwtParser;

Expand Down Expand Up @@ -82,7 +83,7 @@ public InternalAuthenticationManager(String sharedSecret, String nodeId)
requireNonNull(nodeId, "nodeId is null");
this.hmac = hmacShaKeyFor(Hashing.sha256().hashString(sharedSecret, UTF_8).asBytes());
this.nodeId = nodeId;
this.jwtParser = newJwtParserBuilder().setSigningKey(hmac).build();
this.jwtParser = newJwtParserBuilder().verifyWith(hmac).build();
}

public static boolean isInternalRequest(ContainerRequestContext request)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,36 @@
import com.google.common.base.CharMatcher;
import com.google.inject.Inject;
import io.airlift.security.pem.PemReader;
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Header;
import io.jsonwebtoken.JweHeader;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.Locator;
import io.jsonwebtoken.UnsupportedJwtException;
import io.jsonwebtoken.security.MacAlgorithm;
import io.jsonwebtoken.security.SecureDigestAlgorithm;
import io.jsonwebtoken.security.SecurityException;

import javax.crypto.spec.SecretKeySpec;
import javax.crypto.SecretKey;

import java.io.File;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.Key;
import java.security.PublicKey;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import static com.google.common.base.CharMatcher.inRange;
import static com.google.common.io.Files.asCharSource;
import static io.jsonwebtoken.security.Keys.hmacShaKeyFor;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.util.Base64.getMimeDecoder;
import static java.util.Objects.requireNonNull;

public class FileSigningKeyResolver
implements SigningKeyResolver
public class FileSigningKeyLocator
implements Locator<Key>
{
private static final String DEFAULT_KEY = "default-key";
private static final CharMatcher INVALID_KID_CHARS = inRange('a', 'z').or(inRange('A', 'Z')).or(inRange('0', '9')).or(CharMatcher.anyOf("_-")).negate();
Expand All @@ -51,15 +56,14 @@ public class FileSigningKeyResolver
private final ConcurrentMap<String, LoadedKey> keys = new ConcurrentHashMap<>();

@Inject
public FileSigningKeyResolver(JwtAuthenticatorConfig config)
public FileSigningKeyLocator(JwtAuthenticatorConfig config)
{
this(config.getKeyFile());
}

public FileSigningKeyResolver(String keyFile)
public FileSigningKeyLocator(String keyFile)
{
this.keyFile = requireNonNull(keyFile, "keyFile is null");

if (keyFile.contains(KEY_ID_VARIABLE)) {
this.staticKey = null;
}
Expand All @@ -69,33 +73,28 @@ public FileSigningKeyResolver(String keyFile)
}

@Override
public Key resolveSigningKey(JwsHeader header, Claims claims)
{
return getKey(header);
}

@Override
public Key resolveSigningKey(JwsHeader header, byte[] plaintext)
public Key locate(Header header)
{
return getKey(header);
return switch (header) {
case JwsHeader jwsHeader -> getKey(jwsHeader.getKeyId(), jwsHeader.getAlgorithm());
case JweHeader jweHeader -> getKey(jweHeader.getKeyId(), jweHeader.getAlgorithm());
default -> throw new UnsupportedJwtException("Cannot locate key for header: %s".formatted(header.getType()));
};
}

private Key getKey(JwsHeader header)
private Key getKey(String keyId, String algorithm)
{
SignatureAlgorithm algorithm = SignatureAlgorithm.forName(header.getAlgorithm());

SecureDigestAlgorithm<?, ?> secureDigestAlgorithm = Jwts.SIG.get().forKey(algorithm);
if (staticKey != null) {
return staticKey.getKey(algorithm);
return staticKey.getKey(secureDigestAlgorithm);
}

String keyId = getKeyId(header);
LoadedKey key = keys.computeIfAbsent(keyId, this::loadKey);
return key.getKey(algorithm);
LoadedKey key = keys.computeIfAbsent(getKeyId(keyId), this::loadKey);
return key.getKey(secureDigestAlgorithm);
}

private static String getKeyId(JwsHeader header)
private static String getKeyId(String keyId)
{
String keyId = header.getKeyId();
if (keyId == null) {
// allow for migration from system not using kid
return DEFAULT_KEY;
Expand Down Expand Up @@ -135,8 +134,8 @@ private static LoadedKey loadKeyFile(File file)

// try to load the key as a base64 encoded HMAC key
try {
byte[] rawKey = getMimeDecoder().decode(data.getBytes(US_ASCII));
return new LoadedKey(rawKey);
SecretKey hmacKey = hmacShaKeyFor(getMimeDecoder().decode(data.getBytes(US_ASCII)));
return new LoadedKey(hmacKey);
}
catch (RuntimeException e) {
throw new SecurityException("Unable to decode HMAC signing key", e);
Expand All @@ -145,28 +144,28 @@ private static LoadedKey loadKeyFile(File file)

private static class LoadedKey
{
private final Key publicKey;
private final byte[] hmacKey;
private final PublicKey publicKey;
private final SecretKey secretKey;

public LoadedKey(Key publicKey)
public LoadedKey(PublicKey publicKey)
{
this.publicKey = requireNonNull(publicKey, "publicKey is null");
this.hmacKey = null;
this.secretKey = null;
}

public LoadedKey(byte[] hmacKey)
public LoadedKey(SecretKey secretKey)
{
this.hmacKey = requireNonNull(hmacKey, "hmacKey is null");
this.secretKey = requireNonNull(secretKey, "secretKey is null");
this.publicKey = null;
}

public Key getKey(SignatureAlgorithm algorithm)
public Key getKey(SecureDigestAlgorithm<?, ?> algorithm)
{
if (algorithm.isHmac()) {
if (hmacKey == null) {
if (algorithm instanceof MacAlgorithm) {
if (secretKey == null) {
throw new UnsupportedJwtException(format("JWT is signed with %s, but no HMAC key is configured", algorithm));
}
return new SecretKeySpec(hmacKey, algorithm.getJcaName());
return secretKey;
}

if (publicKey == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,39 @@
*/
package io.trino.server.security.jwt;

import io.jsonwebtoken.Claims;
import io.jsonwebtoken.Header;
import io.jsonwebtoken.JweHeader;
import io.jsonwebtoken.JwsHeader;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.Locator;
import io.jsonwebtoken.UnsupportedJwtException;
import io.jsonwebtoken.security.SecurityException;

import java.security.Key;

import static java.util.Objects.requireNonNull;

public class JwkSigningKeyResolver
implements SigningKeyResolver
public class JwkSigningKeyLocator
implements Locator<Key>
{
private final JwkService keys;

public JwkSigningKeyResolver(JwkService keys)
public JwkSigningKeyLocator(JwkService keys)
{
this.keys = requireNonNull(keys, "keys is null");
}

@Override
public Key resolveSigningKey(JwsHeader header, Claims claims)
public Key locate(Header header)
{
return getKey(header);
return switch (header) {
case JwsHeader jwsHeader -> getKey(jwsHeader.getKeyId());
case JweHeader jweHeader -> getKey(jweHeader.getKeyId());
default -> throw new UnsupportedJwtException("Cannot locate key for header: %s".formatted(header.getType()));
};
}

@Override
public Key resolveSigningKey(JwsHeader header, byte[] plaintext)
{
return getKey(header);
}

private Key getKey(JwsHeader header)
private Key getKey(String keyId)
{
String keyId = header.getKeyId();
if (keyId == null) {
throw new SecurityException("Key ID is required");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import io.jsonwebtoken.Claims;
import io.jsonwebtoken.JwtParser;
import io.jsonwebtoken.JwtParserBuilder;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.Locator;
import io.trino.server.security.AbstractBearerAuthenticator;
import io.trino.server.security.AuthenticationException;
import io.trino.server.security.UserMapping;
Expand All @@ -26,6 +26,7 @@
import io.trino.spi.security.Identity;
import jakarta.ws.rs.container.ContainerRequestContext;

import java.security.Key;
import java.util.Collection;
import java.util.Optional;

Expand All @@ -43,13 +44,13 @@ public class JwtAuthenticator
private final Optional<String> requiredAudience;

@Inject
public JwtAuthenticator(JwtAuthenticatorConfig config, @ForJwt SigningKeyResolver signingKeyResolver)
public JwtAuthenticator(JwtAuthenticatorConfig config, @ForJwt Locator<Key> signingKeyLocator)
{
principalField = config.getPrincipalField();
requiredAudience = Optional.ofNullable(config.getRequiredAudience());

JwtParserBuilder jwtParser = newJwtParserBuilder()
.setSigningKeyResolver(signingKeyResolver);
.keyLocator(signingKeyLocator);

if (config.getRequiredIssuer() != null) {
jwtParser.requireIssuer(config.getRequiredIssuer());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
import com.google.inject.Provides;
import com.google.inject.Scopes;
import com.google.inject.Singleton;
import com.google.inject.TypeLiteral;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.airlift.http.client.HttpClient;
import io.jsonwebtoken.SigningKeyResolver;
import io.jsonwebtoken.Locator;

import java.net.URI;
import java.security.Key;

import static io.airlift.configuration.ConditionalModule.conditionalModule;
import static io.airlift.configuration.ConfigBinder.configBinder;
Expand All @@ -39,7 +41,7 @@ protected void setup(Binder binder)
JwtAuthenticatorConfig.class,
JwtAuthenticatorSupportModule::isHttp,
new JwkModule(),
jwkBinder -> jwkBinder.bind(SigningKeyResolver.class).annotatedWith(ForJwt.class).to(FileSigningKeyResolver.class).in(Scopes.SINGLETON)));
jwkBinder -> jwkBinder.bind(new TypeLiteral<Locator<Key>>() {}).annotatedWith(ForJwt.class).to(FileSigningKeyLocator.class).in(Scopes.SINGLETON)));
}

private static boolean isHttp(JwtAuthenticatorConfig config)
Expand Down Expand Up @@ -67,9 +69,9 @@ public static JwkService createJwkService(JwtAuthenticatorConfig config, @ForJwt
@Provides
@Singleton
@ForJwt
public static SigningKeyResolver createJwkSigningKeyResolver(@ForJwt JwkService jwkService)
public static Locator<Key> createJwkSigningKeyLocator(@ForJwt JwkService jwkService)
{
return new JwkSigningKeyResolver(jwkService);
return new JwkSigningKeyLocator(jwkService);
}

// this module can be added multiple times, and this prevents multiple processing by Guice
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
import io.trino.server.ui.OAuthWebUiCookie;
import jakarta.ws.rs.core.Response;

import javax.crypto.SecretKey;

import java.io.IOException;
import java.net.URI;
import java.security.Key;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
Expand Down Expand Up @@ -69,7 +70,7 @@ public class OAuth2Service
private final String failureHtml;

private final TemporalAmount challengeTimeout;
private final Key stateHmac;
private final SecretKey stateHmac;
private final JwtParser jwtParser;

private final OAuth2TokenHandler tokenHandler;
Expand All @@ -96,7 +97,7 @@ public OAuth2Service(
.map(key -> sha256().hashString(key, UTF_8).asBytes())
.orElseGet(() -> secureRandomBytes(32)));
this.jwtParser = newJwtParserBuilder()
.setSigningKey(stateHmac)
.verifyWith(stateHmac)
.requireAudience(STATE_AUDIENCE_UI)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import jakarta.ws.rs.core.UriBuilder;
import jakarta.ws.rs.core.UriInfo;

import javax.crypto.SecretKey;

import java.net.URI;
import java.net.URISyntaxException;
import java.security.Key;
Expand Down Expand Up @@ -87,10 +89,10 @@ public FormWebUiAuthenticationFilter(
hmacBytes = new byte[32];
new SecureRandom().nextBytes(hmacBytes);
}
Key hmac = hmacShaKeyFor(hmacBytes);
SecretKey hmac = hmacShaKeyFor(hmacBytes);

this.jwtParser = newJwtParserBuilder()
.setSigningKey(hmac)
.verifyWith(hmac)
.requireAudience(TRINO_UI_AUDIENCE)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1040,7 +1040,7 @@ private static class TokenServer
private final String issuer = "http://example.com/";
private final String clientId = "clientID";
private final Date tokenExpiration = Date.from(ZonedDateTime.now().plusMinutes(5).toInstant());
private final JwtParser jwtParser = newJwtParserBuilder().setSigningKey(JWK_PUBLIC_KEY).build();
private final JwtParser jwtParser = newJwtParserBuilder().verifyWith(JWK_PUBLIC_KEY).build();
private final Optional<String> principalField;
private final TestingHttpServer jwkServer;
private final String accessToken;
Expand Down
Loading