diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java index b10ea12a8e8a2..45ded8e0ed84c 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/JweTokenSerializer.java @@ -17,6 +17,8 @@ import com.nimbusds.jose.EncryptionMethod; import com.nimbusds.jose.JOSEException; import com.nimbusds.jose.JWEAlgorithm; +import com.nimbusds.jose.JWEDecrypter; +import com.nimbusds.jose.JWEEncrypter; import com.nimbusds.jose.JWEHeader; import com.nimbusds.jose.JWEObject; import com.nimbusds.jose.KeyLengthException; @@ -49,8 +51,8 @@ public class JweTokenSerializer implements TokenPairSerializer { - private static final JWEAlgorithm ALGORITHM = JWEAlgorithm.A256KW; - private static final EncryptionMethod ENCRYPTION_METHOD = EncryptionMethod.A256CBC_HS512; + private final JWEHeader encryptionHeader; + private static final CompressionCodec COMPRESSION_CODEC = new ZstdCodec(); private static final String ACCESS_TOKEN_KEY = "access_token"; private static final String EXPIRATION_TIME_KEY = "expiration_time"; @@ -61,8 +63,8 @@ public class JweTokenSerializer private final String audience; private final Duration tokenExpiration; private final JwtParser parser; - private final AESEncrypter jweEncrypter; - private final AESDecrypter jweDecrypter; + private final JWEEncrypter jweEncrypter; + private final JWEDecrypter jweDecrypter; private final String principalField; public JweTokenSerializer( @@ -84,6 +86,7 @@ public JweTokenSerializer( this.audience = requireNonNull(audience, "issuer is null"); this.clock = requireNonNull(clock, "clock is null"); this.tokenExpiration = requireNonNull(tokenExpiration, "tokenExpiration is null"); + this.encryptionHeader = createEncryptionHeader(secretKey); this.parser = newJwtParserBuilder() .setClock(() -> Date.from(clock.instant())) @@ -93,11 +96,26 @@ public JweTokenSerializer( .build(); } + private JWEHeader createEncryptionHeader(SecretKey key) + { + int keyLength = key.getEncoded().length; + switch (keyLength) { + case 16: + return new JWEHeader(JWEAlgorithm.A128GCMKW, EncryptionMethod.A128GCM); + case 24: + return new JWEHeader(JWEAlgorithm.A192GCMKW, EncryptionMethod.A192GCM); + case 32: + return new JWEHeader(JWEAlgorithm.A256GCMKW, EncryptionMethod.A256GCM); + default: + throw new IllegalArgumentException( + String.format("Secret key size must be either 16, 24 or 32 bytes but was %d", keyLength)); + } + } + @Override public TokenPair deserialize(String token) { requireNonNull(token, "token is null"); - try { JWEObject jwe = JWEObject.parse(token); jwe.decrypt(jweDecrypter); @@ -139,9 +157,7 @@ public String serialize(TokenPair tokenPair) .compressWith(COMPRESSION_CODEC); try { - JWEObject jwe = new JWEObject( - new JWEHeader(ALGORITHM, ENCRYPTION_METHOD), - new Payload(jwt.compact())); + JWEObject jwe = new JWEObject(encryptionHeader, new Payload(jwt.compact())); jwe.encrypt(jweEncrypter); return jwe.serialize(); } diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/RefreshTokensConfig.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/RefreshTokensConfig.java index 7d25c142f0332..d4e1783c89128 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/RefreshTokensConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/RefreshTokensConfig.java @@ -18,10 +18,10 @@ import com.facebook.airlift.configuration.ConfigSecuritySensitive; import com.facebook.airlift.units.Duration; import io.jsonwebtoken.io.Decoders; -import io.jsonwebtoken.security.Keys; import jakarta.validation.constraints.NotEmpty; import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; import static com.google.common.base.Strings.isNullOrEmpty; import static java.util.concurrent.TimeUnit.HOURS; @@ -85,7 +85,7 @@ public RefreshTokensConfig setSecretKey(String key) return this; } - secretKey = Keys.hmacShaKeyFor(Decoders.BASE64.decode(key)); + secretKey = new SecretKeySpec(Decoders.BASE64.decode(key), "AES"); return this; } diff --git a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java index f178a4048adc8..a92e7fdf00e06 100644 --- a/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java +++ b/presto-main/src/main/java/com/facebook/presto/server/security/oauth2/TokenPairSerializer.java @@ -87,5 +87,10 @@ public Optional getRefreshToken() { return refreshToken; } + + public static TokenPair withAccessAndRefreshTokens(String accessToken, Date expiration, @Nullable String refreshToken) + { + return new TokenPair(accessToken, expiration, Optional.ofNullable(refreshToken)); + } } } diff --git a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestJweTokenSerializer.java b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestJweTokenSerializer.java index 7a41c0fdaeeea..e29442b667523 100644 --- a/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestJweTokenSerializer.java +++ b/presto-main/src/test/java/com/facebook/presto/server/security/oauth2/TestJweTokenSerializer.java @@ -18,26 +18,33 @@ import com.nimbusds.jose.KeyLengthException; import io.jsonwebtoken.ExpiredJwtException; import io.jsonwebtoken.Jwts; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.net.URI; import java.security.GeneralSecurityException; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; import java.time.Clock; import java.time.Instant; import java.time.ZoneId; import java.time.ZonedDateTime; +import java.util.Base64; import java.util.Calendar; import java.util.Date; import java.util.Map; import java.util.Optional; +import java.util.Random; import static com.facebook.airlift.units.Duration.succinctDuration; import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair.accessAndRefreshTokens; +import static com.facebook.presto.server.security.oauth2.TokenPairSerializer.TokenPair.withAccessAndRefreshTokens; import static java.time.temporal.ChronoUnit.MILLIS; import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; public class TestJweTokenSerializer { @@ -45,7 +52,7 @@ public class TestJweTokenSerializer public void testSerialization() throws Exception { - JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS)); + JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), randomEncodedSecret()); Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime(); String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token")); @@ -56,6 +63,66 @@ public void testSerialization() assertThat(deserializedTokenPair.getRefreshToken()).isEqualTo(Optional.of("refresh_token")); } + @Test(dataProvider = "wrongSecretsProvider") + public void testDeserializationWithWrongSecret(String encryptionSecret, String decryptionSecret) + { + assertThatThrownBy(() -> assertRoundTrip(Optional.ofNullable(encryptionSecret), Optional.ofNullable(decryptionSecret))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("Decryption failed") + .hasStackTraceContaining("Tag mismatch!"); + } + + @DataProvider + public Object[][] wrongSecretsProvider() + { + return new Object[][] { + {randomEncodedSecret(), randomEncodedSecret()}, + {randomEncodedSecret(16), randomEncodedSecret(24)}, + {null, null}, // This will generate two different secret keys + {null, randomEncodedSecret()}, + {randomEncodedSecret(), null} + }; + } + + @Test + public void testSerializationDeserializationRoundTripWithDifferentKeyLengths() + throws Exception + { + for (int keySize : new int[] {16, 24, 32}) { + String secret = randomEncodedSecret(keySize); + assertRoundTrip(secret, secret); + } + } + + @Test + public void testSerializationFailsWithWrongKeySize() + { + for (int wrongKeySize : new int[] {8, 64, 128}) { + String tooShortSecret = randomEncodedSecret(wrongKeySize); + assertThatThrownBy(() -> assertRoundTrip(tooShortSecret, tooShortSecret)) + .hasStackTraceContaining("The Key Encryption Key length must be 128 bits (16 bytes), 192 bits (24 bytes) or 256 bits (32 bytes)"); + } + } + + private void assertRoundTrip(String serializerSecret, String deserializerSecret) + throws Exception + { + assertRoundTrip(Optional.of(serializerSecret), Optional.of(deserializerSecret)); + } + + private void assertRoundTrip(Optional serializerSecret, Optional deserializerSecret) + throws Exception + { + JweTokenSerializer serializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), serializerSecret); + JweTokenSerializer deserializer = tokenSerializer(Clock.systemUTC(), succinctDuration(5, SECONDS), deserializerSecret); + Date expiration = new Calendar.Builder().setDate(2023, 6, 22).build().getTime(); + TokenPair tokenPair = withAccessAndRefreshTokens(randomEncodedSecret(), expiration, randomEncodedSecret()); + TokenPair postSerPair = deserializer.deserialize(serializer.serialize(tokenPair)); + assertEquals(tokenPair.getAccessToken(), postSerPair.getAccessToken()); + assertEquals(tokenPair.getRefreshToken(), postSerPair.getRefreshToken()); + assertEquals(tokenPair.getExpiration(), postSerPair.getExpiration()); + } + @Test public void testTokenDeserializationAfterTimeoutButBeforeExpirationExtension() throws Exception @@ -63,7 +130,8 @@ public void testTokenDeserializationAfterTimeoutButBeforeExpirationExtension() TestingClock clock = new TestingClock(); JweTokenSerializer serializer = tokenSerializer( clock, - succinctDuration(12, MINUTES)); + succinctDuration(12, MINUTES), + randomEncodedSecret()); Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime(); String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token")); clock.advanceBy(succinctDuration(10, MINUTES)); @@ -82,7 +150,8 @@ public void testTokenDeserializationAfterTimeoutAndExpirationExtension() JweTokenSerializer serializer = tokenSerializer( clock, - succinctDuration(12, MINUTES)); + succinctDuration(12, MINUTES), + randomEncodedSecret()); Date expiration = new Calendar.Builder().setDate(2022, 6, 22).build().getTime(); String serializedTokenPair = serializer.serialize(accessAndRefreshTokens("access_token", expiration, "refresh_token")); @@ -104,6 +173,40 @@ private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration tokenExpiration); } + private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration, String encodedSecretKey) + throws GeneralSecurityException, KeyLengthException + { + return tokenSerializer(clock, tokenExpiration, Optional.of(encodedSecretKey)); + } + + private JweTokenSerializer tokenSerializer(Clock clock, Duration tokenExpiration, Optional secretKey) + throws NoSuchAlgorithmException, KeyLengthException + { + RefreshTokensConfig refreshTokensConfig = new RefreshTokensConfig(); + secretKey.ifPresent(refreshTokensConfig::setSecretKey); + return new JweTokenSerializer( + refreshTokensConfig, + new Oauth2ClientStub(), + "presto_coordinator_test_version", + "presto_coordinator", + "sub", + clock, + tokenExpiration); + } + + private static String randomEncodedSecret() + { + return randomEncodedSecret(24); + } + + private static String randomEncodedSecret(int length) + { + Random random = new SecureRandom(); + final byte[] buffer = new byte[length]; + random.nextBytes(buffer); + return Base64.getEncoder().encodeToString(buffer); + } + static class Oauth2ClientStub implements OAuth2Client {