@@ -54,10 +54,11 @@ use std::{
5454 io:: { self , Write } ,
5555 sync:: {
5656 atomic:: { AtomicBool , AtomicUsize , Ordering } ,
57- Arc , RwLock ,
57+ Arc ,
5858 } ,
5959 time:: { Duration , Instant } ,
6060} ;
61+ use parking_lot:: RwLock ;
6162use url:: Url ;
6263
6364/// Error that can occur when building a [Transport]
@@ -68,6 +69,9 @@ pub enum BuildError {
6869
6970 /// Certificate error
7071 Cert ( reqwest:: Error ) ,
72+
73+ /// Configuration error
74+ Config ( String ) ,
7175}
7276
7377impl From < io:: Error > for BuildError {
@@ -88,13 +92,15 @@ impl error::Error for BuildError {
8892 match * self {
8993 BuildError :: Io ( ref err) => err. description ( ) ,
9094 BuildError :: Cert ( ref err) => err. description ( ) ,
95+ BuildError :: Config ( ref err) => err. as_str ( ) ,
9196 }
9297 }
9398
9499 fn cause ( & self ) -> Option < & dyn error:: Error > {
95100 match * self {
96101 BuildError :: Io ( ref err) => Some ( err as & dyn error:: Error ) ,
97102 BuildError :: Cert ( ref err) => Some ( err as & dyn error:: Error ) ,
103+ BuildError :: Config ( _) => None ,
98104 }
99105 }
100106}
@@ -104,6 +110,7 @@ impl fmt::Display for BuildError {
104110 match * self {
105111 BuildError :: Io ( ref err) => fmt:: Display :: fmt ( err, f) ,
106112 BuildError :: Cert ( ref err) => fmt:: Display :: fmt ( err, f) ,
113+ BuildError :: Config ( ref err) => fmt:: Display :: fmt ( err, f) ,
107114 }
108115 }
109116}
@@ -337,7 +344,7 @@ impl TransportBuilder {
337344 if let Some ( c) = self . proxy_credentials {
338345 proxy = match c {
339346 Credentials :: Basic ( u, p) => proxy. basic_auth ( & u, & p) ,
340- _ => proxy ,
347+ _ => return Err ( BuildError :: Config ( "Only Basic Authentication is supported for proxies" . into ( ) ) ) ,
341348 } ;
342349 }
343350 client_builder = client_builder. proxy ( proxy) ;
@@ -348,7 +355,7 @@ impl TransportBuilder {
348355 client,
349356 conn_pool : self . conn_pool ,
350357 request_body_compression : self . request_body_compression ,
351- credentials : self . credentials ,
358+ credentials : Arc :: new ( RwLock :: new ( self . credentials ) ) ,
352359 send_meta : self . meta_header ,
353360 } )
354361 }
@@ -393,7 +400,7 @@ impl Connection {
393400#[ derive( Debug , Clone ) ]
394401pub struct Transport {
395402 client : reqwest:: Client ,
396- credentials : Option < Credentials > ,
403+ credentials : Arc < RwLock < Option < Credentials > > > ,
397404 request_body_compression : bool ,
398405 conn_pool : Arc < dyn ConnectionPool > ,
399406 send_meta : bool ,
@@ -478,7 +485,7 @@ impl Transport {
478485 /// [Elasticsearch service in Elastic Cloud](https://www.elastic.co/cloud/).
479486 ///
480487 /// * `cloud_id`: The Elastic Cloud Id retrieved from the cloud web console, that uniquely
481- /// identifies the deployment instance.
488+ /// identifies the deployment instance.
482489 /// * `credentials`: A set of credentials the client should use to authenticate to Elasticsearch service.
483490 pub fn cloud ( cloud_id : & str , credentials : Credentials ) -> Result < Transport , Error > {
484491 let conn_pool = CloudConnectionPool :: new ( cloud_id) ?;
@@ -513,7 +520,8 @@ impl Transport {
513520 // set credentials before any headers, as credentials append to existing headers in reqwest,
514521 // whilst setting headers() overwrites, so if an Authorization header has been specified
515522 // on a specific request, we want it to overwrite.
516- if let Some ( c) = & self . credentials {
523+ let creds_guard = self . credentials . read ( ) ;
524+ if let Some ( c) = creds_guard. as_ref ( ) {
517525 request_builder = match c {
518526 Credentials :: Basic ( u, p) => request_builder. basic_auth ( u, Some ( p) ) ,
519527 Credentials :: Bearer ( t) => request_builder. bearer_auth ( t) ,
@@ -537,6 +545,7 @@ impl Transport {
537545 }
538546 }
539547 }
548+ drop ( creds_guard) ;
540549
541550 // default headers first, overwrite with any provided
542551 let mut request_headers = HeaderMap :: with_capacity ( 4 + headers. len ( ) ) ;
@@ -696,6 +705,12 @@ impl Transport {
696705 Err ( e) => Err ( e. into ( ) ) ,
697706 }
698707 }
708+
709+ /// Update the auth credentials for this transport and all its clones, and all clients
710+ /// using them. Typically used to refresh a bearer token.
711+ pub fn set_auth ( & self , credentials : Credentials ) {
712+ * self . credentials . write ( ) = Some ( credentials) ;
713+ }
699714}
700715
701716impl Default for Transport {
@@ -895,14 +910,14 @@ where
895910 ConnSelector : ConnectionSelector + Clone ,
896911{
897912 fn next ( & self ) -> Connection {
898- let inner = self . inner . read ( ) . expect ( "lock poisoned" ) ;
913+ let inner = self . inner . read ( ) ;
899914 self . connection_selector
900915 . try_next ( & inner. connections )
901916 . unwrap ( )
902917 }
903918
904919 fn reseedable ( & self ) -> bool {
905- let inner = self . inner . read ( ) . expect ( "lock poisoned" ) ;
920+ let inner = self . inner . read ( ) ;
906921 let reseed_frequency = match self . reseed_frequency {
907922 Some ( wait) => wait,
908923 None => return false ,
@@ -928,10 +943,11 @@ where
928943 }
929944
930945 fn reseed ( & self , mut connection : Vec < Connection > ) {
931- let mut inner = self . inner . write ( ) . expect ( "lock poisoned" ) ;
946+ let mut inner = self . inner . write ( ) ;
932947 inner. last_update = Some ( Instant :: now ( ) ) ;
933948 inner. connections . clear ( ) ;
934949 inner. connections . append ( & mut connection) ;
950+ drop ( inner) ;
935951 self . reseeding . store ( false , Ordering :: Relaxed ) ;
936952 }
937953}
@@ -1210,7 +1226,7 @@ pub mod tests {
12101226 ) ;
12111227
12121228 // Set internal last_update to a minute ago
1213- let mut inner = connection_pool. inner . write ( ) . expect ( "lock poisoned" ) ;
1229+ let mut inner = connection_pool. inner . write ( ) ;
12141230 inner. last_update = Some ( Instant :: now ( ) - Duration :: from_secs ( 60 ) ) ;
12151231 drop ( inner) ;
12161232
@@ -1249,4 +1265,37 @@ pub mod tests {
12491265 let connections = MultiNodeConnectionPool :: round_robin ( vec ! [ ] , None ) ;
12501266 connections. next ( ) ;
12511267 }
1268+
1269+ #[ test]
1270+ fn set_credentials ( ) -> anyhow:: Result < ( ) > {
1271+ let t1: Transport = TransportBuilder :: new ( SingleNodeConnectionPool :: default ( ) )
1272+ . auth ( Credentials :: Basic ( "foo" . to_string ( ) , "bar" . to_string ( ) ) )
1273+ . build ( ) ?;
1274+
1275+ if let Some ( Credentials :: Basic ( login, password) ) = t1. credentials . read ( ) . as_ref ( ) {
1276+ assert_eq ! ( login, "foo" ) ;
1277+ assert_eq ! ( password, "bar" ) ;
1278+ } else {
1279+ panic ! ( "Expected Basic credentials" ) ;
1280+ }
1281+
1282+ let t2 = t1. clone ( ) ;
1283+
1284+ t1. set_auth ( Credentials :: Bearer ( "The bear" . to_string ( ) ) ) ;
1285+
1286+ if let Some ( Credentials :: Bearer ( token) ) = t1. credentials . read ( ) . as_ref ( ) {
1287+ assert_eq ! ( token, "The bear" ) ;
1288+ } else {
1289+ panic ! ( "Expected Bearer credentials" ) ;
1290+ }
1291+
1292+ // Verify that cloned transport also has the same credentials
1293+ if let Some ( Credentials :: Bearer ( token) ) = t2. credentials . read ( ) . as_ref ( ) {
1294+ assert_eq ! ( token, "The bear" ) ;
1295+ } else {
1296+ panic ! ( "Expected Bearer credentials" ) ;
1297+ }
1298+
1299+ Ok ( ( ) )
1300+ }
12521301}
0 commit comments