2121import  com .google .common .hash .HashFunction ;
2222import  com .google .common .util .concurrent .FluentFuture ;
2323import  com .google .common .util .concurrent .ListenableFuture ;
24+ import  com .google .common .util .concurrent .SettableFuture ;
2425import  com .google .inject .Inject ;
25- import  io .airlift .http .client .HttpClient ;
26- import  io .airlift .http .client .Request ;
2726import  io .airlift .log .Logger ;
2827import  io .airlift .units .Duration ;
29- import  io .trino .proxy .ProxyResponseHandler .ProxyResponse ;
3028import  jakarta .annotation .PreDestroy ;
3129import  jakarta .servlet .http .HttpServletRequest ;
3230import  jakarta .ws .rs .DELETE ;
4240import  jakarta .ws .rs .core .Response ;
4341import  jakarta .ws .rs .core .Response .ResponseBuilder ;
4442import  jakarta .ws .rs .core .Response .Status ;
43+ import  jakarta .ws .rs .core .UriBuilder ;
4544import  jakarta .ws .rs .core .UriInfo ;
45+ import  okhttp3 .Call ;
46+ import  okhttp3 .Callback ;
47+ import  okhttp3 .Headers ;
48+ import  okhttp3 .MediaType ;
49+ import  okhttp3 .OkHttpClient ;
50+ import  okhttp3 .Request ;
51+ import  okhttp3 .RequestBody ;
52+ import  okhttp3 .ResponseBody ;
4653
4754import  java .io .ByteArrayOutputStream ;
4855import  java .io .File ;
6067import  static  com .fasterxml .jackson .core .JsonToken .VALUE_STRING ;
6168import  static  com .google .common .hash .Hashing .hmacSha256 ;
6269import  static  com .google .common .net .HttpHeaders .AUTHORIZATION ;
70+ import  static  com .google .common .net .HttpHeaders .CONTENT_TYPE ;
6371import  static  com .google .common .net .HttpHeaders .COOKIE ;
6472import  static  com .google .common .net .HttpHeaders .SET_COOKIE ;
6573import  static  com .google .common .net .HttpHeaders .USER_AGENT ;
6674import  static  com .google .common .util .concurrent .MoreExecutors .directExecutor ;
6775import  static  io .airlift .concurrent .Threads .daemonThreadsNamed ;
68- import  static  io .airlift .http .client .HttpUriBuilder .uriBuilderFrom ;
69- import  static  io .airlift .http .client .Request .Builder .prepareDelete ;
70- import  static  io .airlift .http .client .Request .Builder .prepareGet ;
71- import  static  io .airlift .http .client .Request .Builder .preparePost ;
72- import  static  io .airlift .http .client .StaticBodyGenerator .createStaticBodyGenerator ;
7376import  static  io .airlift .jaxrs .AsyncResponseHandler .bindAsyncResponse ;
7477import  static  io .trino .plugin .base .util .JsonUtils .jsonFactoryBuilder ;
7578import  static  jakarta .ws .rs .core .MediaType .APPLICATION_JSON ;
7679import  static  jakarta .ws .rs .core .MediaType .TEXT_PLAIN_TYPE ;
7780import  static  jakarta .ws .rs .core .Response .Status .BAD_GATEWAY ;
7881import  static  jakarta .ws .rs .core .Response .Status .FORBIDDEN ;
82+ import  static  jakarta .ws .rs .core .Response .Status .NO_CONTENT ;
83+ import  static  jakarta .ws .rs .core .Response .Status .OK ;
7984import  static  jakarta .ws .rs .core .Response .noContent ;
8085import  static  java .lang .String .format ;
8186import  static  java .nio .charset .StandardCharsets .UTF_8 ;
@@ -96,13 +101,15 @@ public class ProxyResource
96101    private  static  final  JsonFactory  JSON_FACTORY  = jsonFactoryBuilder ().disable (CANONICALIZE_FIELD_NAMES ).build ();
97102
98103    private  final  ExecutorService  executor  = newCachedThreadPool (daemonThreadsNamed ("proxy-%s" ));
99-     private  final  HttpClient  httpClient ;
104+     private  final  OkHttpClient  httpClient ;
100105    private  final  JsonWebTokenHandler  jwtHandler ;
101106    private  final  URI  remoteUri ;
102107    private  final  HashFunction  hmac ;
103108
109+     private  static  final  com .google .common .net .MediaType  JSON  = com .google .common .net .MediaType .create ("application" , "json" );
110+ 
104111    @ Inject 
105-     public  ProxyResource (@ ForProxy  HttpClient  httpClient , JsonWebTokenHandler  jwtHandler , ProxyConfig  config )
112+     public  ProxyResource (@ ForProxy  OkHttpClient  httpClient , JsonWebTokenHandler  jwtHandler , ProxyConfig  config )
106113    {
107114        this .httpClient  = requireNonNull (httpClient , "httpClient is null" );
108115        this .jwtHandler  = requireNonNull (jwtHandler , "jwtHandler is null" );
@@ -123,8 +130,9 @@ public void getInfo(
123130            @ Context  HttpServletRequest  servletRequest ,
124131            @ Suspended  AsyncResponse  asyncResponse )
125132    {
126-         Request .Builder  request  = prepareGet ()
127-                 .setUri (uriBuilderFrom (remoteUri ).replacePath ("/v1/info" ).build ());
133+         Request .Builder  request  = new  Request .Builder ()
134+                 .get ()
135+                 .url (UriBuilder .fromUri (remoteUri ).replacePath ("/v1/info" ).build ().toString ());
128136
129137        performRequest (servletRequest , asyncResponse , request , response  ->
130138                responseWithHeaders (Response .ok (response .getBody ()), response ));
@@ -139,9 +147,9 @@ public void postStatement(
139147            @ Context  UriInfo  uriInfo ,
140148            @ Suspended  AsyncResponse  asyncResponse )
141149    {
142-         Request .Builder  request  = preparePost ()
143-                 .setUri ( uriBuilderFrom ( remoteUri ). replacePath ( "/v1/statement" ). build ( ))
144-                 .setBodyGenerator ( createStaticBodyGenerator ( statement ,  UTF_8 ));
150+         Request .Builder  request  = new   Request . Builder ()
151+                 .post ( RequestBody . create ( statement ,  MediaType . parse ( "application/json" ) ))
152+                 .url ( UriBuilder . fromUri ( remoteUri ). replacePath ( "/v1/ statement" ). build (). toString ( ));
145153
146154        performRequest (servletRequest , asyncResponse , request , response  -> buildResponse (uriInfo , response ));
147155    }
@@ -160,7 +168,9 @@ public void getNext(
160168            throw  badRequest (FORBIDDEN , "Failed to validate HMAC of URI" );
161169        }
162170
163-         Request .Builder  request  = prepareGet ().setUri (URI .create (uri ));
171+         Request .Builder  request  = new  Request .Builder ()
172+                 .get ()
173+                 .url (uri );
164174
165175        performRequest (servletRequest , asyncResponse , request , response  -> buildResponse (uriInfo , response ));
166176    }
@@ -178,7 +188,9 @@ public void cancelQuery(
178188            throw  badRequest (FORBIDDEN , "Failed to validate HMAC of URI" );
179189        }
180190
181-         Request .Builder  request  = prepareDelete ().setUri (URI .create (uri ));
191+         Request .Builder  request  = new  Request .Builder ()
192+                 .delete ()
193+                 .url (uri );
182194
183195        performRequest (servletRequest , asyncResponse , request , response  -> responseWithHeaders (noContent (), response ));
184196    }
@@ -204,10 +216,7 @@ else if (name.equalsIgnoreCase(USER_AGENT)) {
204216            }
205217        }
206218
207-         Request  request  = requestBuilder 
208-                 .setPreserveAuthorizationOnRedirect (true )
209-                 .build ();
210- 
219+         Request  request  = requestBuilder .build ();
211220        ListenableFuture <Response > future  = executeHttp (request )
212221                .transform (responseBuilder ::apply , executor )
213222                .catching (ProxyException .class , e  -> handleProxyException (request , e ), directExecutor ());
@@ -243,7 +252,54 @@ private void setupAsyncResponse(AsyncResponse asyncResponse, ListenableFuture<Re
243252
244253    private  FluentFuture <ProxyResponse > executeHttp (Request  request )
245254    {
246-         return  FluentFuture .from (httpClient .executeAsync (request , new  ProxyResponseHandler ()));
255+         SettableFuture <ProxyResponse > future  = SettableFuture .create ();
256+ 
257+         // Enqueue the call and resolve the future 
258+         httpClient .newCall (request ).enqueue (new  Callback () {
259+             @ Override 
260+             public  void  onFailure (Call  call , IOException  e )
261+             {
262+                 future .setException (e ); // Set the exception if the request fails 
263+             }
264+ 
265+             @ Override 
266+             public  void  onResponse (Call  call , okhttp3 .Response  response )
267+             {
268+                 if  (response .code () == NO_CONTENT .getStatusCode ()) {
269+                     future .set (new  ProxyResponse (response .headers (), new  byte [0 ]));
270+                     return ;
271+                 }
272+ 
273+                 if  (response .code () != OK .getStatusCode ()) {
274+                     try  (ResponseBody  body  = response .body ()) {
275+                         future .setException (new  ProxyException (format ("Bad status code from remote Trino server: %s: %s" , response .code (), body .string ())));
276+                         return ;
277+                     }
278+                     catch  (IOException  e ) {
279+                         future .setException (e );
280+                         return ;
281+                     }
282+                 }
283+ 
284+                 String  contentType  = response .header (CONTENT_TYPE );
285+                 if  (contentType  == null ) {
286+                     throw  new  ProxyException ("No Content-Type set in response from remote Trino server" );
287+                 }
288+                 if  (!com .google .common .net .MediaType .parse (contentType ).is (JSON )) {
289+                     throw  new  ProxyException ("Bad Content-Type from remote Trino server:"  + contentType );
290+                 }
291+ 
292+                 try  (ResponseBody  body  = response .body ()) {
293+                     future .set (new  ProxyResponse (response .headers (), body .bytes ()));
294+                     return ;
295+                 }
296+                 catch  (IOException  e ) {
297+                     throw  new  ProxyException ("Failed reading response from remote Trino server" , e );
298+                 }
299+             }
300+         });
301+ 
302+         return  FluentFuture .from (future );
247303    }
248304
249305    private  void  setupBearerToken (HttpServletRequest  servletRequest , Request .Builder  requestBuilder )
@@ -264,7 +320,7 @@ private void setupBearerToken(HttpServletRequest servletRequest, Request.Builder
264320
265321    private  static  <T > T  handleProxyException (Request  request , ProxyException  e )
266322    {
267-         log .warn (e , "Proxy request failed: %s %s" , request .getMethod (), request .getUri ());
323+         log .warn (e , "Proxy request failed: %s %s" , request .method (), request .url ());
268324        throw  badRequest (BAD_GATEWAY , e .getMessage ());
269325    }
270326
@@ -284,10 +340,9 @@ private static boolean isTrinoHeader(String name)
284340
285341    private  static  Response  responseWithHeaders (ResponseBuilder  builder , ProxyResponse  response )
286342    {
287-         response .getHeaders ().forEach ((headerName , value ) -> {
288-             String  name  = headerName .toString ();
343+         response .getHeaders ().names ().forEach (name  -> {
289344            if  (isTrinoHeader (name ) || name .equalsIgnoreCase (SET_COOKIE )) {
290-                 builder .header (name , value );
345+                 builder .header (name , response . getHeaders (). get ( name ) );
291346            }
292347        });
293348        return  builder .build ();
@@ -365,4 +420,26 @@ private static byte[] loadSharedSecret(File file)
365420            throw  new  RuntimeException ("Failed to load shared secret file: "  + file , e );
366421        }
367422    }
423+ 
424+     public  static  class  ProxyResponse 
425+     {
426+         private  final  Headers  headers ;
427+         private  final  byte [] body ;
428+ 
429+         ProxyResponse (Headers  headers , byte [] body )
430+         {
431+             this .headers  = requireNonNull (headers , "headers is null" );
432+             this .body  = requireNonNull (body , "body is null" );
433+         }
434+ 
435+         public  Headers  getHeaders ()
436+         {
437+             return  headers ;
438+         }
439+ 
440+         public  byte [] getBody ()
441+         {
442+             return  body ;
443+         }
444+     }
368445}
0 commit comments