Skip to content

Commit 116b08c

Browse files
committed
Replace airlift's auth preserving client with okhttp
1 parent 3431609 commit 116b08c

File tree

5 files changed

+164
-138
lines changed

5 files changed

+164
-138
lines changed

service/trino-proxy/pom.xml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,23 +33,23 @@
3333
</dependency>
3434

3535
<dependency>
36-
<groupId>io.airlift</groupId>
37-
<artifactId>bootstrap</artifactId>
36+
<groupId>com.squareup.okhttp3</groupId>
37+
<artifactId>okhttp</artifactId>
3838
</dependency>
3939

4040
<dependency>
4141
<groupId>io.airlift</groupId>
42-
<artifactId>concurrent</artifactId>
42+
<artifactId>bootstrap</artifactId>
4343
</dependency>
4444

4545
<dependency>
4646
<groupId>io.airlift</groupId>
47-
<artifactId>configuration</artifactId>
47+
<artifactId>concurrent</artifactId>
4848
</dependency>
4949

5050
<dependency>
5151
<groupId>io.airlift</groupId>
52-
<artifactId>http-client</artifactId>
52+
<artifactId>configuration</artifactId>
5353
</dependency>
5454

5555
<dependency>
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package io.trino.proxy;
15+
16+
import okhttp3.Interceptor;
17+
import okhttp3.Request;
18+
import okhttp3.Response;
19+
20+
import java.io.IOException;
21+
22+
import static com.google.common.net.HttpHeaders.AUTHORIZATION;
23+
24+
public class AuthPreservingInterceptor
25+
implements Interceptor
26+
{
27+
@Override
28+
public Response intercept(Interceptor.Chain chain)
29+
throws IOException
30+
{
31+
Request request = chain.request();
32+
Response response = chain.proceed(request);
33+
34+
if (response.isRedirect()) {
35+
String authHeader = request.header(AUTHORIZATION);
36+
if (authHeader != null) {
37+
Request redirectedRequest = response.request().newBuilder()
38+
.header(AUTHORIZATION, authHeader)
39+
.build();
40+
return chain.proceed(redirectedRequest);
41+
}
42+
}
43+
return response;
44+
}
45+
}

service/trino-proxy/src/main/java/io/trino/proxy/ProxyModule.java

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515

1616
import com.google.inject.Binder;
1717
import com.google.inject.Module;
18+
import com.google.inject.Provides;
1819
import com.google.inject.Scopes;
20+
import okhttp3.OkHttpClient;
1921

2022
import static io.airlift.configuration.ConfigBinder.configBinder;
21-
import static io.airlift.http.client.HttpClientBinder.httpClientBinder;
2223
import static io.airlift.jaxrs.JaxrsBinder.jaxrsBinder;
2324

2425
public class ProxyModule
@@ -27,13 +28,19 @@ public class ProxyModule
2728
@Override
2829
public void configure(Binder binder)
2930
{
30-
httpClientBinder(binder).bindHttpClient("proxy", ForProxy.class);
31-
3231
configBinder(binder).bindConfig(ProxyConfig.class);
3332
configBinder(binder).bindConfig(JwtHandlerConfig.class, "proxy");
34-
3533
jaxrsBinder(binder).bind(ProxyResource.class);
3634

3735
binder.bind(JsonWebTokenHandler.class).in(Scopes.SINGLETON);
3836
}
37+
38+
@ForProxy
39+
@Provides
40+
private OkHttpClient createHttpClient()
41+
{
42+
return new OkHttpClient.Builder()
43+
.addInterceptor(new AuthPreservingInterceptor())
44+
.build();
45+
}
3946
}

service/trino-proxy/src/main/java/io/trino/proxy/ProxyResource.java

Lines changed: 103 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,10 @@
2121
import com.google.common.hash.HashFunction;
2222
import com.google.common.util.concurrent.FluentFuture;
2323
import com.google.common.util.concurrent.ListenableFuture;
24+
import com.google.common.util.concurrent.SettableFuture;
2425
import com.google.inject.Inject;
25-
import io.airlift.http.client.HttpClient;
26-
import io.airlift.http.client.Request;
2726
import io.airlift.log.Logger;
2827
import io.airlift.units.Duration;
29-
import io.trino.proxy.ProxyResponseHandler.ProxyResponse;
3028
import jakarta.annotation.PreDestroy;
3129
import jakarta.servlet.http.HttpServletRequest;
3230
import jakarta.ws.rs.DELETE;
@@ -42,7 +40,16 @@
4240
import jakarta.ws.rs.core.Response;
4341
import jakarta.ws.rs.core.Response.ResponseBuilder;
4442
import jakarta.ws.rs.core.Response.Status;
43+
import jakarta.ws.rs.core.UriBuilder;
4544
import jakarta.ws.rs.core.UriInfo;
45+
import okhttp3.Call;
46+
import okhttp3.Callback;
47+
import okhttp3.Headers;
48+
import okhttp3.MediaType;
49+
import okhttp3.OkHttpClient;
50+
import okhttp3.Request;
51+
import okhttp3.RequestBody;
52+
import okhttp3.ResponseBody;
4653

4754
import java.io.ByteArrayOutputStream;
4855
import java.io.File;
@@ -60,22 +67,20 @@
6067
import static com.fasterxml.jackson.core.JsonToken.VALUE_STRING;
6168
import static com.google.common.hash.Hashing.hmacSha256;
6269
import static com.google.common.net.HttpHeaders.AUTHORIZATION;
70+
import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
6371
import static com.google.common.net.HttpHeaders.COOKIE;
6472
import static com.google.common.net.HttpHeaders.SET_COOKIE;
6573
import static com.google.common.net.HttpHeaders.USER_AGENT;
6674
import static com.google.common.util.concurrent.MoreExecutors.directExecutor;
6775
import static io.airlift.concurrent.Threads.daemonThreadsNamed;
68-
import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom;
69-
import static io.airlift.http.client.Request.Builder.prepareDelete;
70-
import static io.airlift.http.client.Request.Builder.prepareGet;
71-
import static io.airlift.http.client.Request.Builder.preparePost;
72-
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
7376
import static io.airlift.jaxrs.AsyncResponseHandler.bindAsyncResponse;
7477
import static io.trino.plugin.base.util.JsonUtils.jsonFactoryBuilder;
7578
import static jakarta.ws.rs.core.MediaType.APPLICATION_JSON;
7679
import static jakarta.ws.rs.core.MediaType.TEXT_PLAIN_TYPE;
7780
import static jakarta.ws.rs.core.Response.Status.BAD_GATEWAY;
7881
import static jakarta.ws.rs.core.Response.Status.FORBIDDEN;
82+
import static jakarta.ws.rs.core.Response.Status.NO_CONTENT;
83+
import static jakarta.ws.rs.core.Response.Status.OK;
7984
import static jakarta.ws.rs.core.Response.noContent;
8085
import static java.lang.String.format;
8186
import static java.nio.charset.StandardCharsets.UTF_8;
@@ -96,13 +101,15 @@ public class ProxyResource
96101
private static final JsonFactory JSON_FACTORY = jsonFactoryBuilder().disable(CANONICALIZE_FIELD_NAMES).build();
97102

98103
private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("proxy-%s"));
99-
private final HttpClient httpClient;
104+
private final OkHttpClient httpClient;
100105
private final JsonWebTokenHandler jwtHandler;
101106
private final URI remoteUri;
102107
private final HashFunction hmac;
103108

109+
private static final com.google.common.net.MediaType JSON = com.google.common.net.MediaType.create("application", "json");
110+
104111
@Inject
105-
public ProxyResource(@ForProxy HttpClient httpClient, JsonWebTokenHandler jwtHandler, ProxyConfig config)
112+
public ProxyResource(@ForProxy OkHttpClient httpClient, JsonWebTokenHandler jwtHandler, ProxyConfig config)
106113
{
107114
this.httpClient = requireNonNull(httpClient, "httpClient is null");
108115
this.jwtHandler = requireNonNull(jwtHandler, "jwtHandler is null");
@@ -123,8 +130,9 @@ public void getInfo(
123130
@Context HttpServletRequest servletRequest,
124131
@Suspended AsyncResponse asyncResponse)
125132
{
126-
Request.Builder request = prepareGet()
127-
.setUri(uriBuilderFrom(remoteUri).replacePath("/v1/info").build());
133+
Request.Builder request = new Request.Builder()
134+
.get()
135+
.url(UriBuilder.fromUri(remoteUri).replacePath("/v1/info").build().toString());
128136

129137
performRequest(servletRequest, asyncResponse, request, response ->
130138
responseWithHeaders(Response.ok(response.getBody()), response));
@@ -139,9 +147,9 @@ public void postStatement(
139147
@Context UriInfo uriInfo,
140148
@Suspended AsyncResponse asyncResponse)
141149
{
142-
Request.Builder request = preparePost()
143-
.setUri(uriBuilderFrom(remoteUri).replacePath("/v1/statement").build())
144-
.setBodyGenerator(createStaticBodyGenerator(statement, UTF_8));
150+
Request.Builder request = new Request.Builder()
151+
.post(RequestBody.create(statement, MediaType.parse("application/json")))
152+
.url(UriBuilder.fromUri(remoteUri).replacePath("/v1/statement").build().toString());
145153

146154
performRequest(servletRequest, asyncResponse, request, response -> buildResponse(uriInfo, response));
147155
}
@@ -160,7 +168,9 @@ public void getNext(
160168
throw badRequest(FORBIDDEN, "Failed to validate HMAC of URI");
161169
}
162170

163-
Request.Builder request = prepareGet().setUri(URI.create(uri));
171+
Request.Builder request = new Request.Builder()
172+
.get()
173+
.url(uri);
164174

165175
performRequest(servletRequest, asyncResponse, request, response -> buildResponse(uriInfo, response));
166176
}
@@ -178,7 +188,9 @@ public void cancelQuery(
178188
throw badRequest(FORBIDDEN, "Failed to validate HMAC of URI");
179189
}
180190

181-
Request.Builder request = prepareDelete().setUri(URI.create(uri));
191+
Request.Builder request = new Request.Builder()
192+
.delete()
193+
.url(uri);
182194

183195
performRequest(servletRequest, asyncResponse, request, response -> responseWithHeaders(noContent(), response));
184196
}
@@ -204,10 +216,7 @@ else if (name.equalsIgnoreCase(USER_AGENT)) {
204216
}
205217
}
206218

207-
Request request = requestBuilder
208-
.setPreserveAuthorizationOnRedirect(true)
209-
.build();
210-
219+
Request request = requestBuilder.build();
211220
ListenableFuture<Response> future = executeHttp(request)
212221
.transform(responseBuilder::apply, executor)
213222
.catching(ProxyException.class, e -> handleProxyException(request, e), directExecutor());
@@ -243,7 +252,54 @@ private void setupAsyncResponse(AsyncResponse asyncResponse, ListenableFuture<Re
243252

244253
private FluentFuture<ProxyResponse> executeHttp(Request request)
245254
{
246-
return FluentFuture.from(httpClient.executeAsync(request, new ProxyResponseHandler()));
255+
SettableFuture<ProxyResponse> future = SettableFuture.create();
256+
257+
// Enqueue the call and resolve the future
258+
httpClient.newCall(request).enqueue(new Callback() {
259+
@Override
260+
public void onFailure(Call call, IOException e)
261+
{
262+
future.setException(e); // Set the exception if the request fails
263+
}
264+
265+
@Override
266+
public void onResponse(Call call, okhttp3.Response response)
267+
{
268+
if (response.code() == NO_CONTENT.getStatusCode()) {
269+
future.set(new ProxyResponse(response.headers(), new byte[0]));
270+
return;
271+
}
272+
273+
if (response.code() != OK.getStatusCode()) {
274+
try (ResponseBody body = response.body()) {
275+
future.setException(new ProxyException(format("Bad status code from remote Trino server: %s: %s", response.code(), body.string())));
276+
return;
277+
}
278+
catch (IOException e) {
279+
future.setException(e);
280+
return;
281+
}
282+
}
283+
284+
String contentType = response.header(CONTENT_TYPE);
285+
if (contentType == null) {
286+
throw new ProxyException("No Content-Type set in response from remote Trino server");
287+
}
288+
if (!com.google.common.net.MediaType.parse(contentType).is(JSON)) {
289+
throw new ProxyException("Bad Content-Type from remote Trino server:" + contentType);
290+
}
291+
292+
try (ResponseBody body = response.body()) {
293+
future.set(new ProxyResponse(response.headers(), body.bytes()));
294+
return;
295+
}
296+
catch (IOException e) {
297+
throw new ProxyException("Failed reading response from remote Trino server", e);
298+
}
299+
}
300+
});
301+
302+
return FluentFuture.from(future);
247303
}
248304

249305
private void setupBearerToken(HttpServletRequest servletRequest, Request.Builder requestBuilder)
@@ -264,7 +320,7 @@ private void setupBearerToken(HttpServletRequest servletRequest, Request.Builder
264320

265321
private static <T> T handleProxyException(Request request, ProxyException e)
266322
{
267-
log.warn(e, "Proxy request failed: %s %s", request.getMethod(), request.getUri());
323+
log.warn(e, "Proxy request failed: %s %s", request.method(), request.url());
268324
throw badRequest(BAD_GATEWAY, e.getMessage());
269325
}
270326

@@ -284,10 +340,9 @@ private static boolean isTrinoHeader(String name)
284340

285341
private static Response responseWithHeaders(ResponseBuilder builder, ProxyResponse response)
286342
{
287-
response.getHeaders().forEach((headerName, value) -> {
288-
String name = headerName.toString();
343+
response.getHeaders().names().forEach(name -> {
289344
if (isTrinoHeader(name) || name.equalsIgnoreCase(SET_COOKIE)) {
290-
builder.header(name, value);
345+
builder.header(name, response.getHeaders().get(name));
291346
}
292347
});
293348
return builder.build();
@@ -365,4 +420,26 @@ private static byte[] loadSharedSecret(File file)
365420
throw new RuntimeException("Failed to load shared secret file: " + file, e);
366421
}
367422
}
423+
424+
public static class ProxyResponse
425+
{
426+
private final Headers headers;
427+
private final byte[] body;
428+
429+
ProxyResponse(Headers headers, byte[] body)
430+
{
431+
this.headers = requireNonNull(headers, "headers is null");
432+
this.body = requireNonNull(body, "body is null");
433+
}
434+
435+
public Headers getHeaders()
436+
{
437+
return headers;
438+
}
439+
440+
public byte[] getBody()
441+
{
442+
return body;
443+
}
444+
}
368445
}

0 commit comments

Comments
 (0)