Skip to content

Commit b91476f

Browse files
s2lomonPraveen2112
authored andcommitted
Add refresh-token support to OAuth2
It adds refresh tokens support to OAuth2. Whenever refresh-token is issued it's beeing wrapped along with access-token into self issued JWT token and send to the client. Whenever token needs to be refreshed, refresh-token is extracted and used to refresh access-token.
1 parent 5a6fb21 commit b91476f

20 files changed

+1036
-32
lines changed

core/trino-main/src/main/java/io/trino/server/security/AbstractBearerAuthenticator.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,10 @@ public Identity authenticate(ContainerRequestContext request, String token)
3838
throws AuthenticationException
3939
{
4040
try {
41-
return createIdentity(token).orElseThrow(() -> needAuthentication(request, "Invalid credentials"));
41+
return createIdentity(token).orElseThrow(() -> needAuthentication(request, Optional.of(token), "Invalid credentials"));
4242
}
4343
catch (JwtException | UserMappingException e) {
44-
throw needAuthentication(request, e.getMessage());
44+
throw needAuthentication(request, Optional.empty(), e.getMessage());
4545
}
4646
catch (RuntimeException e) {
4747
throw new RuntimeException("Authentication error", e);
@@ -53,7 +53,7 @@ public String extractToken(ContainerRequestContext request)
5353
{
5454
List<String> headers = request.getHeaders().get(AUTHORIZATION);
5555
if (headers == null || headers.size() == 0) {
56-
throw needAuthentication(request, null);
56+
throw needAuthentication(request, Optional.empty(), null);
5757
}
5858
if (headers.size() > 1) {
5959
throw new IllegalArgumentException(format("Multiple %s headers detected: %s, where only single %s header is supported", AUTHORIZATION, headers, AUTHORIZATION));
@@ -62,17 +62,17 @@ public String extractToken(ContainerRequestContext request)
6262
String header = headers.get(0);
6363
int space = header.indexOf(' ');
6464
if ((space < 0) || !header.substring(0, space).equalsIgnoreCase("bearer")) {
65-
throw needAuthentication(request, null);
65+
throw needAuthentication(request, Optional.empty(), null);
6666
}
6767
String token = header.substring(space + 1).trim();
6868
if (token.isEmpty()) {
69-
throw needAuthentication(request, null);
69+
throw needAuthentication(request, Optional.empty(), null);
7070
}
7171
return token;
7272
}
7373

7474
protected abstract Optional<Identity> createIdentity(String token)
7575
throws UserMappingException;
7676

77-
protected abstract AuthenticationException needAuthentication(ContainerRequestContext request, String message);
77+
protected abstract AuthenticationException needAuthentication(ContainerRequestContext request, Optional<String> currentToken, String message);
7878
}

core/trino-main/src/main/java/io/trino/server/security/jwt/JwtAuthenticator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ protected Optional<Identity> createIdentity(String token)
7272
}
7373

7474
@Override
75-
protected AuthenticationException needAuthentication(ContainerRequestContext request, String message)
75+
protected AuthenticationException needAuthentication(ContainerRequestContext request, Optional<String> currentToken, String message)
7676
{
7777
return new AuthenticationException(message, "Bearer realm=\"Trino\", token_type=\"JWT\"");
7878
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.server.security.oauth2;
15+
16+
import com.google.inject.BindingAnnotation;
17+
18+
import java.lang.annotation.Retention;
19+
import java.lang.annotation.Target;
20+
21+
import static java.lang.annotation.ElementType.FIELD;
22+
import static java.lang.annotation.ElementType.METHOD;
23+
import static java.lang.annotation.ElementType.PARAMETER;
24+
import static java.lang.annotation.RetentionPolicy.RUNTIME;
25+
26+
@Retention(RUNTIME)
27+
@Target({FIELD, PARAMETER, METHOD})
28+
@BindingAnnotation
29+
public @interface ForRefreshTokens
30+
{
31+
}
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.server.security.oauth2;
15+
16+
import com.nimbusds.jose.EncryptionMethod;
17+
import com.nimbusds.jose.JOSEException;
18+
import com.nimbusds.jose.JWEAlgorithm;
19+
import com.nimbusds.jose.JWEHeader;
20+
import com.nimbusds.jose.JWEObject;
21+
import com.nimbusds.jose.KeyLengthException;
22+
import com.nimbusds.jose.Payload;
23+
import com.nimbusds.jose.crypto.AESDecrypter;
24+
import com.nimbusds.jose.crypto.AESEncrypter;
25+
import io.airlift.units.Duration;
26+
import io.jsonwebtoken.Claims;
27+
import io.jsonwebtoken.CompressionCodec;
28+
import io.jsonwebtoken.CompressionException;
29+
import io.jsonwebtoken.Header;
30+
import io.jsonwebtoken.JwtBuilder;
31+
import io.jsonwebtoken.JwtParser;
32+
33+
import javax.crypto.KeyGenerator;
34+
import javax.crypto.SecretKey;
35+
36+
import java.security.NoSuchAlgorithmException;
37+
import java.text.ParseException;
38+
import java.time.Clock;
39+
import java.util.Date;
40+
import java.util.Map;
41+
import java.util.Optional;
42+
43+
import static com.google.common.base.Preconditions.checkState;
44+
import static io.trino.server.security.jwt.JwtUtil.newJwtBuilder;
45+
import static io.trino.server.security.jwt.JwtUtil.newJwtParserBuilder;
46+
import static java.lang.String.format;
47+
import static java.util.Objects.requireNonNull;
48+
49+
public class JweTokenSerializer
50+
implements TokenPairSerializer
51+
{
52+
private static final JWEAlgorithm ALGORITHM = JWEAlgorithm.A256KW;
53+
private static final EncryptionMethod ENCRYPTION_METHOD = EncryptionMethod.A256CBC_HS512;
54+
private static final CompressionCodec COMPRESSION_CODEC = new ZstdCodec();
55+
private static final String ACCESS_TOKEN_KEY = "access_token";
56+
private static final String EXPIRATION_TIME_KEY = "expiration_time";
57+
private static final String REFRESH_TOKEN_KEY = "refresh_token";
58+
private final OAuth2Client client;
59+
private final Clock clock;
60+
private final String issuer;
61+
private final String audience;
62+
private final Duration tokenExpiration;
63+
private final JwtParser parser;
64+
private final AESEncrypter jweEncrypter;
65+
private final AESDecrypter jweDecrypter;
66+
private final String principalField;
67+
68+
public JweTokenSerializer(
69+
RefreshTokensConfig config,
70+
OAuth2Client client,
71+
String issuer,
72+
String audience,
73+
String principalField,
74+
Clock clock,
75+
Duration tokenExpiration)
76+
throws KeyLengthException, NoSuchAlgorithmException
77+
{
78+
SecretKey secretKey = createKey(requireNonNull(config, "config is null"));
79+
this.jweEncrypter = new AESEncrypter(secretKey);
80+
this.jweDecrypter = new AESDecrypter(secretKey);
81+
this.client = requireNonNull(client, "client is null");
82+
this.issuer = requireNonNull(issuer, "issuer is null");
83+
this.principalField = requireNonNull(principalField, "principalField is null");
84+
this.audience = requireNonNull(audience, "issuer is null");
85+
this.clock = requireNonNull(clock, "clock is null");
86+
this.tokenExpiration = requireNonNull(tokenExpiration, "tokenExpiration is null");
87+
88+
this.parser = newJwtParserBuilder()
89+
.setClock(() -> Date.from(clock.instant()))
90+
.requireIssuer(this.issuer)
91+
.requireAudience(this.audience)
92+
.setCompressionCodecResolver(JweTokenSerializer::resolveCompressionCodec)
93+
.build();
94+
}
95+
96+
@Override
97+
public TokenPair deserialize(String token)
98+
{
99+
requireNonNull(token, "token is null");
100+
101+
try {
102+
JWEObject jwe = JWEObject.parse(token);
103+
jwe.decrypt(jweDecrypter);
104+
Claims claims = parser.parseClaimsJwt(jwe.getPayload().toString()).getBody();
105+
return TokenPair.accessAndRefreshTokens(
106+
claims.get(ACCESS_TOKEN_KEY, String.class),
107+
claims.get(EXPIRATION_TIME_KEY, Date.class),
108+
claims.get(REFRESH_TOKEN_KEY, String.class));
109+
}
110+
catch (ParseException ex) {
111+
throw new IllegalArgumentException("Malformed jwt token", ex);
112+
}
113+
catch (JOSEException ex) {
114+
throw new IllegalArgumentException("Decryption failed", ex);
115+
}
116+
}
117+
118+
@Override
119+
public String serialize(TokenPair tokenPair)
120+
{
121+
requireNonNull(tokenPair, "tokenPair is null");
122+
123+
Optional<Map<String, Object>> accessTokenClaims = client.getClaims(tokenPair.getAccessToken());
124+
if (accessTokenClaims.isEmpty()) {
125+
throw new IllegalArgumentException("Claims are missing");
126+
}
127+
Map<String, Object> claims = accessTokenClaims.get();
128+
if (!claims.containsKey(principalField)) {
129+
throw new IllegalArgumentException(format("%s field is missing", principalField));
130+
}
131+
JwtBuilder jwt = newJwtBuilder()
132+
.setExpiration(Date.from(clock.instant().plusMillis(tokenExpiration.toMillis())))
133+
.claim(principalField, claims.get(principalField).toString())
134+
.setAudience(audience)
135+
.setIssuer(issuer)
136+
.claim(ACCESS_TOKEN_KEY, tokenPair.getAccessToken())
137+
.claim(EXPIRATION_TIME_KEY, tokenPair.getExpiration())
138+
.claim(REFRESH_TOKEN_KEY, tokenPair.getRefreshToken().orElseThrow(JweTokenSerializer::throwExceptionForNonExistingRefreshToken))
139+
.compressWith(COMPRESSION_CODEC);
140+
141+
try {
142+
JWEObject jwe = new JWEObject(
143+
new JWEHeader(ALGORITHM, ENCRYPTION_METHOD),
144+
new Payload(jwt.compact()));
145+
jwe.encrypt(jweEncrypter);
146+
return jwe.serialize();
147+
}
148+
catch (JOSEException ex) {
149+
throw new IllegalStateException("Encryption failed", ex);
150+
}
151+
}
152+
153+
private static SecretKey createKey(RefreshTokensConfig config)
154+
throws NoSuchAlgorithmException
155+
{
156+
SecretKey signingKey = config.getSecretKey();
157+
if (signingKey == null) {
158+
KeyGenerator generator = KeyGenerator.getInstance("AES");
159+
generator.init(256);
160+
return generator.generateKey();
161+
}
162+
return signingKey;
163+
}
164+
165+
private static RuntimeException throwExceptionForNonExistingRefreshToken()
166+
{
167+
throw new IllegalStateException("Expected refresh token to be present. Please check your identity provider setup, or disable refresh tokens");
168+
}
169+
170+
private static CompressionCodec resolveCompressionCodec(Header header)
171+
throws CompressionException
172+
{
173+
if (header.getCompressionAlgorithm() != null) {
174+
checkState(header.getCompressionAlgorithm().equals(ZstdCodec.CODEC_NAME), "Unknown codec '%s' used for token compression", header.getCompressionAlgorithm());
175+
return COMPRESSION_CODEC;
176+
}
177+
return null;
178+
}
179+
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.server.security.oauth2;
15+
16+
import com.google.inject.Binder;
17+
import com.google.inject.Inject;
18+
import com.google.inject.Provides;
19+
import com.google.inject.Singleton;
20+
import com.nimbusds.jose.KeyLengthException;
21+
import io.airlift.configuration.AbstractConfigurationAwareModule;
22+
import io.trino.client.NodeVersion;
23+
24+
import java.security.NoSuchAlgorithmException;
25+
import java.time.Clock;
26+
27+
import static io.airlift.configuration.ConfigBinder.configBinder;
28+
29+
public class JweTokenSerializerModule
30+
extends AbstractConfigurationAwareModule
31+
{
32+
@Override
33+
protected void setup(Binder binder)
34+
{
35+
configBinder(binder).bindConfig(RefreshTokensConfig.class);
36+
}
37+
38+
@Provides
39+
@Singleton
40+
@Inject
41+
public TokenPairSerializer getTokenPairSerializer(
42+
OAuth2Client client,
43+
NodeVersion nodeVersion,
44+
RefreshTokensConfig config,
45+
OAuth2Config oAuth2Config)
46+
throws KeyLengthException, NoSuchAlgorithmException
47+
{
48+
return new JweTokenSerializer(
49+
config,
50+
client,
51+
config.getIssuer() + "_" + nodeVersion.getVersion(),
52+
config.getAudience(),
53+
oAuth2Config.getPrincipalField(),
54+
Clock.systemUTC(),
55+
config.getTokenExpiration());
56+
}
57+
}

0 commit comments

Comments
 (0)