diff --git a/presto-main/src/main/java/io/prestosql/server/InternalAuthenticationManager.java b/presto-main/src/main/java/io/prestosql/server/InternalAuthenticationManager.java index 7b2dd0bfefa5..e6c95af8d237 100644 --- a/presto-main/src/main/java/io/prestosql/server/InternalAuthenticationManager.java +++ b/presto-main/src/main/java/io/prestosql/server/InternalAuthenticationManager.java @@ -18,10 +18,7 @@ import io.airlift.http.client.Request; import io.airlift.log.Logger; import io.airlift.node.NodeInfo; -import io.jsonwebtoken.Claims; -import io.jsonwebtoken.Jws; import io.jsonwebtoken.JwtException; -import io.jsonwebtoken.JwtParser; import io.jsonwebtoken.Jwts; import io.jsonwebtoken.SignatureAlgorithm; import io.prestosql.server.security.InternalPrincipal; @@ -33,6 +30,7 @@ import java.time.ZonedDateTime; import java.util.Date; import java.util.Optional; +import java.util.function.Function; import java.util.function.Supplier; import static io.airlift.http.client.Request.Builder.fromRequest; @@ -46,7 +44,7 @@ public class InternalAuthenticationManager private static final String PRESTO_INTERNAL_BEARER = "X-Presto-Internal-Bearer"; - private final Optional jwtParser; + private final Optional> jwtParser; private final Optional> jwtGenerator; @Inject @@ -59,7 +57,7 @@ public InternalAuthenticationManager(Optional sharedSecret, String nodeI { if (sharedSecret.isPresent()) { byte[] hmac = Hashing.sha256().hashString(sharedSecret.get(), UTF_8).asBytes(); - this.jwtParser = Optional.of(Jwts.parser().setSigningKey(hmac)); + this.jwtParser = Optional.of(jwt -> parseJwt(hmac, jwt)); this.jwtGenerator = Optional.of(() -> generateJwt(hmac, nodeId)); } else { @@ -77,6 +75,15 @@ private static String generateJwt(byte[] hmac, String nodeId) .compact(); } + private static String parseJwt(byte[] hmac, String jwt) + { + return Jwts.parser() + .setSigningKey(hmac) + .parseClaimsJws(jwt) + .getBody() + .getSubject(); + } + public boolean isInternalRequest(HttpServletRequest request) { return request.getHeader(PRESTO_INTERNAL_BEARER) != null; @@ -91,8 +98,7 @@ public Principal authenticateInternalRequest(HttpServletRequest request) String internalBarer = request.getHeader(PRESTO_INTERNAL_BEARER); try { - Jws claimsJws = jwtParser.get().parseClaimsJws(internalBarer); - String subject = claimsJws.getBody().getSubject(); + String subject = jwtParser.get().apply(internalBarer); return new InternalPrincipal(subject); } catch (JwtException e) {