Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,17 @@ public class JweTokenSerializer
implements TokenPairSerializer
{
private static final Logger LOG = Logger.get(JweTokenSerializer.class);
private static final JWEAlgorithm ALGORITHM = JWEAlgorithm.A256KW;
private static final EncryptionMethod ENCRYPTION_METHOD = EncryptionMethod.A256CBC_HS512;
private static final CompressionCodec COMPRESSION_CODEC = new ZstdCodec();
private static final String ACCESS_TOKEN_KEY = "access_token";
private static final String EXPIRATION_TIME_KEY = "expiration_time";
private static final String REFRESH_TOKEN_KEY = "refresh_token";
private final JweEncryptedSerializer jweSerializer;
private final OAuth2Client client;
private final Clock clock;
private final String issuer;
private final String audience;
private final Duration tokenExpiration;
private final JwtParser parser;
private final AESEncrypter jweEncrypter;
private final AESDecrypter jweDecrypter;
private final String principalField;

public JweTokenSerializer(
Expand All @@ -74,11 +71,8 @@ public JweTokenSerializer(
String principalField,
Clock clock,
Duration tokenExpiration)
throws KeyLengthException, NoSuchAlgorithmException
{
SecretKey secretKey = createKey(config);
this.jweEncrypter = new AESEncrypter(secretKey);
this.jweDecrypter = new AESDecrypter(secretKey);
this.jweSerializer = new JweEncryptedSerializer(getOrGenerateKey(config));
this.client = requireNonNull(client, "client is null");
this.issuer = requireNonNull(issuer, "issuer is null");
this.principalField = requireNonNull(principalField, "principalField is null");
Expand All @@ -100,19 +94,14 @@ public TokenPair deserialize(String token)
requireNonNull(token, "token is null");

try {
JWEObject jwe = JWEObject.parse(token);
jwe.decrypt(jweDecrypter);
Claims claims = parser.parseClaimsJwt(jwe.getPayload().toString()).getBody();
return TokenPair.accessAndRefreshTokens(
Claims claims = parser.parseClaimsJwt(jweSerializer.deserialize(token)).getBody();
return TokenPair.withAccessAndRefreshTokens(
claims.get(ACCESS_TOKEN_KEY, String.class),
claims.get(EXPIRATION_TIME_KEY, Date.class),
claims.get(REFRESH_TOKEN_KEY, String.class));
}
catch (ParseException ex) {
return TokenPair.accessToken(token);
}
catch (JOSEException ex) {
throw new IllegalArgumentException("Decryption failed", ex);
return TokenPair.withAccessToken(token);
}
}

Expand All @@ -121,7 +110,7 @@ public String serialize(TokenPair tokenPair)
{
requireNonNull(tokenPair, "tokenPair is null");

Map<String, Object> claims = client.getClaims(tokenPair.getAccessToken()).orElseThrow(() -> new IllegalArgumentException("Claims are missing"));
Map<String, Object> claims = client.getClaims(tokenPair.accessToken()).orElseThrow(() -> new IllegalArgumentException("Claims are missing"));
if (!claims.containsKey(principalField)) {
throw new IllegalArgumentException(format("%s field is missing", principalField));
}
Expand All @@ -130,37 +119,31 @@ public String serialize(TokenPair tokenPair)
.claim(principalField, claims.get(principalField).toString())
.setAudience(audience)
.setIssuer(issuer)
.claim(ACCESS_TOKEN_KEY, tokenPair.getAccessToken())
.claim(EXPIRATION_TIME_KEY, tokenPair.getExpiration())
.claim(ACCESS_TOKEN_KEY, tokenPair.accessToken())
.claim(EXPIRATION_TIME_KEY, tokenPair.expiration())
.compressWith(COMPRESSION_CODEC);

if (tokenPair.getRefreshToken().isPresent()) {
jwt.claim(REFRESH_TOKEN_KEY, tokenPair.getRefreshToken().orElseThrow());
if (tokenPair.refreshToken().isPresent()) {
jwt.claim(REFRESH_TOKEN_KEY, tokenPair.refreshToken().orElseThrow());
}
else {
LOG.info("No refresh token has been issued, although coordinator expects one. Please check your IdP whether that is correct behaviour");
}

try {
JWEObject jwe = new JWEObject(
new JWEHeader(ALGORITHM, ENCRYPTION_METHOD),
new Payload(jwt.compact()));
jwe.encrypt(jweEncrypter);
return jwe.serialize();
}
catch (JOSEException ex) {
throw new IllegalStateException("Encryption failed", ex);
}
return jweSerializer.serialize(jwt.compact());
}

private static SecretKey createKey(RefreshTokensConfig config)
throws NoSuchAlgorithmException
private static SecretKey getOrGenerateKey(RefreshTokensConfig config)
{
SecretKey signingKey = config.getSecretKey();
if (signingKey == null) {
KeyGenerator generator = KeyGenerator.getInstance("AES");
generator.init(256);
return generator.generateKey();
try {
KeyGenerator generator = KeyGenerator.getInstance("AES");
generator.init(256);
return generator.generateKey();
}
catch (NoSuchAlgorithmException e) {
throw new RuntimeException(e);
}
}
return signingKey;
}
Expand All @@ -174,4 +157,59 @@ private static CompressionCodec resolveCompressionCodec(Header<?> header)
}
return null;
}

private static class JweEncryptedSerializer
{
private final AESEncrypter jweEncrypter;
private final AESDecrypter jweDecrypter;
private final JWEHeader encryptionHeader;

private JweEncryptedSerializer(SecretKey secretKey)
{
try {
this.encryptionHeader = createEncryptionHeader(secretKey);
this.jweEncrypter = new AESEncrypter(secretKey);
this.jweDecrypter = new AESDecrypter(secretKey);
}
catch (KeyLengthException e) {
throw new RuntimeException(e);
}
}

private JWEHeader createEncryptionHeader(SecretKey key)
{
int keyLength = key.getEncoded().length;
return switch (keyLength) {
case 16 -> new JWEHeader(JWEAlgorithm.A128GCMKW, EncryptionMethod.A128GCM);
case 24 -> new JWEHeader(JWEAlgorithm.A192GCMKW, EncryptionMethod.A192GCM);
case 32 -> new JWEHeader(JWEAlgorithm.A256GCMKW, EncryptionMethod.A256GCM);
default -> throw new IllegalArgumentException("Secret key size must be either 16, 24 or 32 bytes but was %d".formatted(keyLength));
};
}

private String serialize(String payload)
{
try {
JWEObject jwe = new JWEObject(encryptionHeader, new Payload(payload));
jwe.encrypt(jweEncrypter);
return jwe.serialize();
}
catch (JOSEException e) {
throw new RuntimeException(e);
}
}

private String deserialize(String token)
throws ParseException
{
try {
JWEObject jwe = JWEObject.parse(token);
jwe.decrypt(jweDecrypter);
return jwe.getPayload().toString();
}
catch (JOSEException e) {
throw new RuntimeException(e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ protected Optional<Identity> createIdentity(String token)
}

TokenPair tokenPair = deserializeToken.get();
if (tokenPair.getExpiration().before(Date.from(Instant.now()))) {
if (tokenPair.expiration().before(Date.from(Instant.now()))) {
return Optional.empty();
}
Optional<Map<String, Object>> claims = client.getClaims(tokenPair.getAccessToken());
Optional<Map<String, Object>> claims = client.getClaims(tokenPair.accessToken());
if (claims.isEmpty()) {
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import io.airlift.configuration.ConfigDescription;
import io.airlift.configuration.ConfigSecuritySensitive;
import io.airlift.units.Duration;
import io.jsonwebtoken.io.Decoders;
import io.jsonwebtoken.security.Keys;

import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
import javax.validation.constraints.NotEmpty;

import java.util.Base64;

import static com.google.common.base.Strings.isNullOrEmpty;
import static java.util.concurrent.TimeUnit.HOURS;

Expand Down Expand Up @@ -82,8 +83,7 @@ public RefreshTokensConfig setSecretKey(String key)
if (isNullOrEmpty(key)) {
return this;
}

secretKey = Keys.hmacShaKeyFor(Decoders.BASE64.decode(key));
secretKey = new SecretKeySpec(Base64.getDecoder().decode(key), "AES");
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,30 @@ public interface TokenPairSerializer
@Override
public TokenPair deserialize(String token)
{
return TokenPair.accessToken(token);
return TokenPair.withAccessToken(token);
}

@Override
public String serialize(TokenPair tokenPair)
{
return tokenPair.getAccessToken();
return tokenPair.accessToken();
}
};

TokenPair deserialize(String token);

String serialize(TokenPair tokenPair);

class TokenPair
record TokenPair(String accessToken, Date expiration, Optional<String> refreshToken)
{
private final String accessToken;
private final Date expiration;
private final Optional<String> refreshToken;

private TokenPair(String accessToken, Date expiration, Optional<String> refreshToken)
public TokenPair
{
this.accessToken = requireNonNull(accessToken, "accessToken is nul");
this.expiration = requireNonNull(expiration, "expiration is null");
this.refreshToken = requireNonNull(refreshToken, "refreshToken is null");
requireNonNull(accessToken, "accessToken is nul");
requireNonNull(expiration, "expiration is null");
requireNonNull(refreshToken, "refreshToken is null");
}

public static TokenPair accessToken(String accessToken)
public static TokenPair withAccessToken(String accessToken)
{
return new TokenPair(accessToken, new Date(MAX_VALUE), Optional.empty());
}
Expand All @@ -69,24 +65,9 @@ public static TokenPair fromOAuth2Response(Response tokens)
return new TokenPair(tokens.getAccessToken(), Date.from(tokens.getExpiration()), tokens.getRefreshToken());
}

public static TokenPair accessAndRefreshTokens(String accessToken, Date expiration, @Nullable String refreshToken)
public static TokenPair withAccessAndRefreshTokens(String accessToken, Date expiration, @Nullable String refreshToken)
{
return new TokenPair(accessToken, expiration, Optional.ofNullable(refreshToken));
}

public String getAccessToken()
{
return accessToken;
}

public Date getExpiration()
{
return expiration;
}

public Optional<String> getRefreshToken()
{
return refreshToken;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public Optional<UUID> refreshToken(TokenPair tokenPair)
{
requireNonNull(tokenPair, "tokenPair is null");

Optional<String> refreshToken = tokenPair.getRefreshToken();
Optional<String> refreshToken = tokenPair.refreshToken();
if (refreshToken.isPresent()) {
UUID refreshingId = UUID.randomUUID();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,17 @@ private Optional<TokenPair> getTokenPair(ContainerRequestContext request)

private boolean tokenNotExpired(TokenPair tokenPair)
{
return tokenPair.getExpiration().after(Date.from(Instant.now()));
return tokenPair.expiration().after(Date.from(Instant.now()));
}

private Optional<Map<String, Object>> getAccessTokenClaims(TokenPair tokenPair)
{
return client.getClaims(tokenPair.getAccessToken());
return client.getClaims(tokenPair.accessToken());
}

private void needAuthentication(ContainerRequestContext request, Optional<TokenPair> tokenPair)
{
Optional<String> refreshToken = tokenPair.flatMap(TokenPair::getRefreshToken);
Optional<String> refreshToken = tokenPair.flatMap(TokenPair::refreshToken);
if (refreshToken.isPresent()) {
try {
redirectForNewToken(request, refreshToken.get());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,8 @@ private void verifyOAuth2Authenticator(boolean webUiEnabled, boolean refreshToke
if (refreshTokensEnabled) {
TokenPairSerializer serializer = server.getInstance(Key.get(TokenPairSerializer.class));
TokenPair tokenPair = serializer.deserialize(getOauthToken(client, bearer.getTokenServer()));
assertEquals(tokenPair.getAccessToken(), tokenServer.getAccessToken());
assertEquals(tokenPair.getRefreshToken(), Optional.of(tokenServer.getRefreshToken()));
assertEquals(tokenPair.accessToken(), tokenServer.getAccessToken());
assertEquals(tokenPair.refreshToken(), Optional.of(tokenServer.getRefreshToken()));
}
else {
assertEquals(getOauthToken(client, bearer.getTokenServer()), tokenServer.getAccessToken());
Expand Down
Loading