Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import com.nimbusds.jose.EncryptionMethod;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JWEAlgorithm;
import com.nimbusds.jose.JWEDecrypter;
import com.nimbusds.jose.JWEEncrypter;
import com.nimbusds.jose.JWEHeader;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.KeyLengthException;
Expand Down Expand Up @@ -49,8 +51,8 @@
public class JweTokenSerializer
implements TokenPairSerializer
{
private static final JWEAlgorithm ALGORITHM = JWEAlgorithm.A256KW;
private static final EncryptionMethod ENCRYPTION_METHOD = EncryptionMethod.A256CBC_HS512;
private final JWEHeader encryptionHeader;

private static final CompressionCodec COMPRESSION_CODEC = new ZstdCodec();
private static final String ACCESS_TOKEN_KEY = "access_token";
private static final String EXPIRATION_TIME_KEY = "expiration_time";
Expand All @@ -61,8 +63,8 @@ public class JweTokenSerializer
private final String audience;
private final Duration tokenExpiration;
private final JwtParser parser;
private final AESEncrypter jweEncrypter;
private final AESDecrypter jweDecrypter;
private final JWEEncrypter jweEncrypter;
private final JWEDecrypter jweDecrypter;
private final String principalField;

public JweTokenSerializer(
Expand All @@ -84,6 +86,7 @@ public JweTokenSerializer(
this.audience = requireNonNull(audience, "issuer is null");
this.clock = requireNonNull(clock, "clock is null");
this.tokenExpiration = requireNonNull(tokenExpiration, "tokenExpiration is null");
this.encryptionHeader = createEncryptionHeader(secretKey);

this.parser = newJwtParserBuilder()
.setClock(() -> Date.from(clock.instant()))
Expand All @@ -93,11 +96,26 @@ public JweTokenSerializer(
.build();
}

private JWEHeader createEncryptionHeader(SecretKey key)
{
int keyLength = key.getEncoded().length;
switch (keyLength) {
case 16:
return new JWEHeader(JWEAlgorithm.A128GCMKW, EncryptionMethod.A128GCM);
case 24:
return new JWEHeader(JWEAlgorithm.A192GCMKW, EncryptionMethod.A192GCM);
case 32:
return new JWEHeader(JWEAlgorithm.A256GCMKW, EncryptionMethod.A256GCM);
Comment thread
yhwang marked this conversation as resolved.
default:
throw new IllegalArgumentException(
String.format("Secret key size must be either 16, 24 or 32 bytes but was %d", keyLength));
}
}

@Override
public TokenPair deserialize(String token)
{
requireNonNull(token, "token is null");

try {
JWEObject jwe = JWEObject.parse(token);
jwe.decrypt(jweDecrypter);
Expand Down Expand Up @@ -139,9 +157,7 @@ public String serialize(TokenPair tokenPair)
.compressWith(COMPRESSION_CODEC);

try {
JWEObject jwe = new JWEObject(
new JWEHeader(ALGORITHM, ENCRYPTION_METHOD),
new Payload(jwt.compact()));
JWEObject jwe = new JWEObject(encryptionHeader, new Payload(jwt.compact()));
jwe.encrypt(jweEncrypter);
return jwe.serialize();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
import com.facebook.airlift.configuration.ConfigSecuritySensitive;
import com.facebook.airlift.units.Duration;
import io.jsonwebtoken.io.Decoders;
import io.jsonwebtoken.security.Keys;
import jakarta.validation.constraints.NotEmpty;

import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

import static com.google.common.base.Strings.isNullOrEmpty;
import static java.util.concurrent.TimeUnit.HOURS;
Expand Down Expand Up @@ -85,7 +85,7 @@ public RefreshTokensConfig setSecretKey(String key)
return this;
}

secretKey = Keys.hmacShaKeyFor(Decoders.BASE64.decode(key));
secretKey = new SecretKeySpec(Decoders.BASE64.decode(key), "AES");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The corresponding Trino PR had a test, is it applicable here, and if so could you add it? trinodb/trino@3c687c3

return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,10 @@ public Optional<String> getRefreshToken()
{
return refreshToken;
}

public static TokenPair withAccessAndRefreshTokens(String accessToken, Date expiration, @Nullable String refreshToken)
{
return new TokenPair(accessToken, expiration, Optional.ofNullable(refreshToken));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,41 @@
import com.nimbusds.jose.KeyLengthException;
import io.jsonwebtoken.ExpiredJwtException;
import io.jsonwebtoken.Jwts;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.net.URI;
import java.security.GeneralSecurityException;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.time.Clock;
import java.time.Instant;
import java.time.ZoneId;
import java.time.ZonedDateTime;
import java.util.Base64;
import java.util.Calendar;
import java.util.Date;
import java.util.Map;
import java.util.Optional;
import java.util.Random;

import static com.facebook.airlift.units.Duration.succinctDuration;
import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair.accessAndRefreshTokens;
import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair.withAccessAndRefreshTokens;
import static java.time.temporal.ChronoUnit.MILLIS;
import static java.util.concurrent.TimeUnit.MINUTES;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;

public class TestJweTokenSerializer
{
@Test
public void testSerialization()
throws Exception
{
JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS));
JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), randomEncodedSecret());

Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime();
String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token"));
Expand All @@ -56,14 +63,75 @@ public void testSerialization()
assertThat(deserializedTokenPair.getRefreshToken()).isEqualTo(Optional.of("refresh_token"));
}

@Test(dataProvider = "wrongSecretsProvider")
public void testDeserializationWithWrongSecret(String encryptionSecret, String decryptionSecret)
{
assertThatThrownBy(() -> assertRoundTrip(Optional.ofNullable(encryptionSecret), Optional.ofNullable(decryptionSecret)))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Decryption failed")
.hasStackTraceContaining("Tag mismatch!");
}

@DataProvider
public Object[][] wrongSecretsProvider()
{
return new Object[][] {
{randomEncodedSecret(), randomEncodedSecret()},
{randomEncodedSecret(16), randomEncodedSecret(24)},
{null, null}, // This will generate two different secret keys
{null, randomEncodedSecret()},
{randomEncodedSecret(), null}
};
}

@Test
public void testSerializationDeserializationRoundTripWithDifferentKeyLengths()
throws Exception
{
for (int keySize : new int[] {16, 24, 32}) {
String secret = randomEncodedSecret(keySize);
assertRoundTrip(secret, secret);
}
}

@Test
public void testSerializationFailsWithWrongKeySize()
{
for (int wrongKeySize : new int[] {8, 64, 128}) {
String tooShortSecret = randomEncodedSecret(wrongKeySize);
assertThatThrownBy(() -> assertRoundTrip(tooShortSecret, tooShortSecret))
.hasStackTraceContaining("The Key Encryption Key length must be 128 bits (16 bytes), 192 bits (24 bytes) or 256 bits (32 bytes)");
}
}

private void assertRoundTrip(String serializerSecret, String deserializerSecret)
throws Exception
{
assertRoundTrip(Optional.of(serializerSecret), Optional.of(deserializerSecret));
}

private void assertRoundTrip(Optional<String> serializerSecret, Optional<String> deserializerSecret)
throws Exception
{
JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), serializerSecret);
JweTokenSerializer deserializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), deserializerSecret);
Date expiration = new Calendar.Builder().setDate(2023, 6, 22).build().getTime();
TokenPair tokenPair = withAccessAndRefreshTokens(randomEncodedSecret(), expiration, randomEncodedSecret());
TokenPair postSerPair = deserializer.deserialize(serializer.serialize(tokenPair));
assertEquals(tokenPair.getAccessToken(), postSerPair.getAccessToken());
assertEquals(tokenPair.getRefreshToken(), postSerPair.getRefreshToken());
assertEquals(tokenPair.getExpiration(), postSerPair.getExpiration());
}

@Test
public void testTokenDeserializationAfterTimeoutButBeforeExpirationExtension()
throws Exception
{
TestingClock clock = new TestingClock();
JweTokenSerializer serializer = tokenSerializer(
clock,
succinctDuration(12, MINUTES));
succinctDuration(12, MINUTES),
randomEncodedSecret());
Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime();
String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token"));
clock.advanceBy(succinctDuration(10, MINUTES));
Expand All @@ -82,7 +150,8 @@ public void testTokenDeserializationAfterTimeoutAndExpirationExtension()

JweTokenSerializer serializer = tokenSerializer(
clock,
succinctDuration(12, MINUTES));
succinctDuration(12, MINUTES),
randomEncodedSecret());
Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime();
String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token"));

Expand All @@ -104,6 +173,40 @@ private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration
tokenExpiration);
}

private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration, String encodedSecretKey)
throws GeneralSecurityException, KeyLengthException
{
return tokenSerializer(clock, tokenExpiration, Optional.of(encodedSecretKey));
}

private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration, Optional<String> secretKey)
throws NoSuchAlgorithmException, KeyLengthException
{
RefreshTokensConfig refreshTokensConfig = new RefreshTokensConfig();
secretKey.ifPresent(refreshTokensConfig::setSecretKey);
return new JweTokenSerializer(
refreshTokensConfig,
new Oauth2ClientStub(),
"presto_coordinator_test_version",
"presto_coordinator",
"sub",
clock,
tokenExpiration);
}

private static String randomEncodedSecret()
{
return randomEncodedSecret(24);
}

private static String randomEncodedSecret(int length)
{
Random random = new SecureRandom();
final byte[] buffer = new byte[length];
random.nextBytes(buffer);
return Base64.getEncoder().encodeToString(buffer);
}

static class Oauth2ClientStub
implements OAuth2Client
{
Expand Down
Loading