2121import com .google .common .hash .HashFunction ;
2222import com .google .common .util .concurrent .FluentFuture ;
2323import com .google .common .util .concurrent .ListenableFuture ;
24+ import com .google .common .util .concurrent .SettableFuture ;
2425import com .google .inject .Inject ;
25- import io .airlift .http .client .HttpClient ;
26- import io .airlift .http .client .Request ;
2726import io .airlift .log .Logger ;
2827import io .airlift .units .Duration ;
29- import io .trino .proxy .ProxyResponseHandler .ProxyResponse ;
3028import jakarta .annotation .PreDestroy ;
3129import jakarta .servlet .http .HttpServletRequest ;
3230import jakarta .ws .rs .DELETE ;
4240import jakarta .ws .rs .core .Response ;
4341import jakarta .ws .rs .core .Response .ResponseBuilder ;
4442import jakarta .ws .rs .core .Response .Status ;
43+ import jakarta .ws .rs .core .UriBuilder ;
4544import jakarta .ws .rs .core .UriInfo ;
45+ import okhttp3 .Call ;
46+ import okhttp3 .Callback ;
47+ import okhttp3 .Headers ;
48+ import okhttp3 .MediaType ;
49+ import okhttp3 .OkHttpClient ;
50+ import okhttp3 .Request ;
51+ import okhttp3 .RequestBody ;
52+ import okhttp3 .ResponseBody ;
4653
4754import java .io .ByteArrayOutputStream ;
4855import java .io .File ;
6067import static com .fasterxml .jackson .core .JsonToken .VALUE_STRING ;
6168import static com .google .common .hash .Hashing .hmacSha256 ;
6269import static com .google .common .net .HttpHeaders .AUTHORIZATION ;
70+ import static com .google .common .net .HttpHeaders .CONTENT_TYPE ;
6371import static com .google .common .net .HttpHeaders .COOKIE ;
6472import static com .google .common .net .HttpHeaders .SET_COOKIE ;
6573import static com .google .common .net .HttpHeaders .USER_AGENT ;
6674import static com .google .common .util .concurrent .MoreExecutors .directExecutor ;
6775import static io .airlift .concurrent .Threads .daemonThreadsNamed ;
68- import static io .airlift .http .client .HttpUriBuilder .uriBuilderFrom ;
69- import static io .airlift .http .client .Request .Builder .prepareDelete ;
70- import static io .airlift .http .client .Request .Builder .prepareGet ;
71- import static io .airlift .http .client .Request .Builder .preparePost ;
72- import static io .airlift .http .client .StaticBodyGenerator .createStaticBodyGenerator ;
7376import static io .airlift .jaxrs .AsyncResponseHandler .bindAsyncResponse ;
7477import static io .trino .plugin .base .util .JsonUtils .jsonFactoryBuilder ;
7578import static jakarta .ws .rs .core .MediaType .APPLICATION_JSON ;
7679import static jakarta .ws .rs .core .MediaType .TEXT_PLAIN_TYPE ;
7780import static jakarta .ws .rs .core .Response .Status .BAD_GATEWAY ;
7881import static jakarta .ws .rs .core .Response .Status .FORBIDDEN ;
82+ import static jakarta .ws .rs .core .Response .Status .NO_CONTENT ;
83+ import static jakarta .ws .rs .core .Response .Status .OK ;
7984import static jakarta .ws .rs .core .Response .noContent ;
8085import static java .lang .String .format ;
8186import static java .nio .charset .StandardCharsets .UTF_8 ;
@@ -96,13 +101,15 @@ public class ProxyResource
96101 private static final JsonFactory JSON_FACTORY = jsonFactoryBuilder ().disable (CANONICALIZE_FIELD_NAMES ).build ();
97102
98103 private final ExecutorService executor = newCachedThreadPool (daemonThreadsNamed ("proxy-%s" ));
99- private final HttpClient httpClient ;
104+ private final OkHttpClient httpClient ;
100105 private final JsonWebTokenHandler jwtHandler ;
101106 private final URI remoteUri ;
102107 private final HashFunction hmac ;
103108
109+ private static final com .google .common .net .MediaType JSON = com .google .common .net .MediaType .create ("application" , "json" );
110+
104111 @ Inject
105- public ProxyResource (@ ForProxy HttpClient httpClient , JsonWebTokenHandler jwtHandler , ProxyConfig config )
112+ public ProxyResource (@ ForProxy OkHttpClient httpClient , JsonWebTokenHandler jwtHandler , ProxyConfig config )
106113 {
107114 this .httpClient = requireNonNull (httpClient , "httpClient is null" );
108115 this .jwtHandler = requireNonNull (jwtHandler , "jwtHandler is null" );
@@ -123,8 +130,9 @@ public void getInfo(
123130 @ Context HttpServletRequest servletRequest ,
124131 @ Suspended AsyncResponse asyncResponse )
125132 {
126- Request .Builder request = prepareGet ()
127- .setUri (uriBuilderFrom (remoteUri ).replacePath ("/v1/info" ).build ());
133+ Request .Builder request = new Request .Builder ()
134+ .get ()
135+ .url (UriBuilder .fromUri (remoteUri ).replacePath ("/v1/info" ).build ().toString ());
128136
129137 performRequest (servletRequest , asyncResponse , request , response ->
130138 responseWithHeaders (Response .ok (response .getBody ()), response ));
@@ -139,9 +147,9 @@ public void postStatement(
139147 @ Context UriInfo uriInfo ,
140148 @ Suspended AsyncResponse asyncResponse )
141149 {
142- Request .Builder request = preparePost ()
143- .setUri ( uriBuilderFrom ( remoteUri ). replacePath ( "/v1/statement" ). build ( ))
144- .setBodyGenerator ( createStaticBodyGenerator ( statement , UTF_8 ));
150+ Request .Builder request = new Request . Builder ()
151+ .post ( RequestBody . create ( statement , MediaType . parse ( "application/json" ) ))
152+ .url ( UriBuilder . fromUri ( remoteUri ). replacePath ( "/v1/ statement" ). build (). toString ( ));
145153
146154 performRequest (servletRequest , asyncResponse , request , response -> buildResponse (uriInfo , response ));
147155 }
@@ -160,7 +168,9 @@ public void getNext(
160168 throw badRequest (FORBIDDEN , "Failed to validate HMAC of URI" );
161169 }
162170
163- Request .Builder request = prepareGet ().setUri (URI .create (uri ));
171+ Request .Builder request = new Request .Builder ()
172+ .get ()
173+ .url (uri );
164174
165175 performRequest (servletRequest , asyncResponse , request , response -> buildResponse (uriInfo , response ));
166176 }
@@ -178,7 +188,9 @@ public void cancelQuery(
178188 throw badRequest (FORBIDDEN , "Failed to validate HMAC of URI" );
179189 }
180190
181- Request .Builder request = prepareDelete ().setUri (URI .create (uri ));
191+ Request .Builder request = new Request .Builder ()
192+ .delete ()
193+ .url (uri );
182194
183195 performRequest (servletRequest , asyncResponse , request , response -> responseWithHeaders (noContent (), response ));
184196 }
@@ -204,10 +216,7 @@ else if (name.equalsIgnoreCase(USER_AGENT)) {
204216 }
205217 }
206218
207- Request request = requestBuilder
208- .setPreserveAuthorizationOnRedirect (true )
209- .build ();
210-
219+ Request request = requestBuilder .build ();
211220 ListenableFuture <Response > future = executeHttp (request )
212221 .transform (responseBuilder ::apply , executor )
213222 .catching (ProxyException .class , e -> handleProxyException (request , e ), directExecutor ());
@@ -243,7 +252,54 @@ private void setupAsyncResponse(AsyncResponse asyncResponse, ListenableFuture<Re
243252
244253 private FluentFuture <ProxyResponse > executeHttp (Request request )
245254 {
246- return FluentFuture .from (httpClient .executeAsync (request , new ProxyResponseHandler ()));
255+ SettableFuture <ProxyResponse > future = SettableFuture .create ();
256+
257+ // Enqueue the call and resolve the future
258+ httpClient .newCall (request ).enqueue (new Callback () {
259+ @ Override
260+ public void onFailure (Call call , IOException e )
261+ {
262+ future .setException (e ); // Set the exception if the request fails
263+ }
264+
265+ @ Override
266+ public void onResponse (Call call , okhttp3 .Response response )
267+ {
268+ if (response .code () == NO_CONTENT .getStatusCode ()) {
269+ future .set (new ProxyResponse (response .headers (), new byte [0 ]));
270+ return ;
271+ }
272+
273+ if (response .code () != OK .getStatusCode ()) {
274+ try (ResponseBody body = response .body ()) {
275+ future .setException (new ProxyException (format ("Bad status code from remote Trino server: %s: %s" , response .code (), body .string ())));
276+ return ;
277+ }
278+ catch (IOException e ) {
279+ future .setException (e );
280+ return ;
281+ }
282+ }
283+
284+ String contentType = response .header (CONTENT_TYPE );
285+ if (contentType == null ) {
286+ throw new ProxyException ("No Content-Type set in response from remote Trino server" );
287+ }
288+ if (!com .google .common .net .MediaType .parse (contentType ).is (JSON )) {
289+ throw new ProxyException ("Bad Content-Type from remote Trino server:" + contentType );
290+ }
291+
292+ try (ResponseBody body = response .body ()) {
293+ future .set (new ProxyResponse (response .headers (), body .bytes ()));
294+ return ;
295+ }
296+ catch (IOException e ) {
297+ throw new ProxyException ("Failed reading response from remote Trino server" , e );
298+ }
299+ }
300+ });
301+
302+ return FluentFuture .from (future );
247303 }
248304
249305 private void setupBearerToken (HttpServletRequest servletRequest , Request .Builder requestBuilder )
@@ -264,7 +320,7 @@ private void setupBearerToken(HttpServletRequest servletRequest, Request.Builder
264320
265321 private static <T > T handleProxyException (Request request , ProxyException e )
266322 {
267- log .warn (e , "Proxy request failed: %s %s" , request .getMethod (), request .getUri ());
323+ log .warn (e , "Proxy request failed: %s %s" , request .method (), request .url ());
268324 throw badRequest (BAD_GATEWAY , e .getMessage ());
269325 }
270326
@@ -284,10 +340,9 @@ private static boolean isTrinoHeader(String name)
284340
285341 private static Response responseWithHeaders (ResponseBuilder builder , ProxyResponse response )
286342 {
287- response .getHeaders ().forEach ((headerName , value ) -> {
288- String name = headerName .toString ();
343+ response .getHeaders ().names ().forEach (name -> {
289344 if (isTrinoHeader (name ) || name .equalsIgnoreCase (SET_COOKIE )) {
290- builder .header (name , value );
345+ builder .header (name , response . getHeaders (). get ( name ) );
291346 }
292347 });
293348 return builder .build ();
@@ -365,4 +420,26 @@ private static byte[] loadSharedSecret(File file)
365420 throw new RuntimeException ("Failed to load shared secret file: " + file , e );
366421 }
367422 }
423+
424+ public static class ProxyResponse
425+ {
426+ private final Headers headers ;
427+ private final byte [] body ;
428+
429+ ProxyResponse (Headers headers , byte [] body )
430+ {
431+ this .headers = requireNonNull (headers , "headers is null" );
432+ this .body = requireNonNull (body , "body is null" );
433+ }
434+
435+ public Headers getHeaders ()
436+ {
437+ return headers ;
438+ }
439+
440+ public byte [] getBody ()
441+ {
442+ return body ;
443+ }
444+ }
368445}
0 commit comments