@@ -368,11 +368,24 @@ type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) t
368368// attempted. If overriding this, be sure to close the body if needed.
369369type ErrorHandler func (resp * http.Response , err error , numTries int ) (* http.Response , error )
370370
371+ type HTTPClient interface {
372+ // Do performs an HTTP request and returns an HTTP response.
373+ Do (* http.Request ) (* http.Response , error )
374+ // Done is called when the client is no longer needed.
375+ Done ()
376+ }
377+
378+ type HTTPClientFactory interface {
379+ // New returns an HTTP client to use for a request, including retries.
380+ New () HTTPClient
381+ }
382+
371383// Client is used to make HTTP requests. It adds additional functionality
372384// like automatic retries to tolerate minor outages.
373385type Client struct {
374- HTTPClient * http.Client // Internal HTTP client.
375- Logger interface {} // Customer logger instance. Can be either Logger or LeveledLogger
386+ HTTPClient * http.Client // Internal HTTP client. This field is used if set, otherwise HTTPClientFactory is used.
387+ HTTPClientFactory HTTPClientFactory
388+ Logger interface {} // Customer logger instance. Can be either Logger or LeveledLogger
376389
377390 RetryWaitMin time.Duration // Minimum time to wait
378391 RetryWaitMax time.Duration // Maximum time to wait
@@ -397,19 +410,18 @@ type Client struct {
397410 ErrorHandler ErrorHandler
398411
399412 loggerInit sync.Once
400- clientInit sync.Once
401413}
402414
403415// NewClient creates a new Client with default settings.
404416func NewClient () * Client {
405417 return & Client {
406- HTTPClient : cleanhttp . DefaultPooledClient () ,
407- Logger : defaultLogger ,
408- RetryWaitMin : defaultRetryWaitMin ,
409- RetryWaitMax : defaultRetryWaitMax ,
410- RetryMax : defaultRetryMax ,
411- CheckRetry : DefaultRetryPolicy ,
412- Backoff : DefaultBackoff ,
418+ HTTPClientFactory : & CleanPooledClientFactory {} ,
419+ Logger : defaultLogger ,
420+ RetryWaitMin : defaultRetryWaitMin ,
421+ RetryWaitMax : defaultRetryWaitMax ,
422+ RetryMax : defaultRetryMax ,
423+ CheckRetry : DefaultRetryPolicy ,
424+ Backoff : DefaultBackoff ,
413425 }
414426}
415427
@@ -573,12 +585,6 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo
573585
574586// Do wraps calling an HTTP method with retries.
575587func (c * Client ) Do (req * Request ) (* http.Response , error ) {
576- c .clientInit .Do (func () {
577- if c .HTTPClient == nil {
578- c .HTTPClient = cleanhttp .DefaultPooledClient ()
579- }
580- })
581-
582588 logger := c .logger ()
583589
584590 if logger != nil {
@@ -590,6 +596,9 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
590596 }
591597 }
592598
599+ httpClient := c .getHTTPClient ()
600+ defer httpClient .Done ()
601+
593602 var resp * http.Response
594603 var attempt int
595604 var shouldRetry bool
@@ -603,7 +612,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
603612 if req .body != nil {
604613 body , err := req .body ()
605614 if err != nil {
606- c .HTTPClient .CloseIdleConnections ()
607615 return resp , err
608616 }
609617 if c , ok := body .(io.ReadCloser ); ok {
@@ -625,7 +633,8 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
625633 }
626634
627635 // Attempt the request
628- resp , doErr = c .HTTPClient .Do (req .Request )
636+
637+ resp , doErr = httpClient .Do (req .Request )
629638
630639 // Check if we should continue with retries.
631640 shouldRetry , checkErr = c .CheckRetry (req .Context (), resp , doErr )
@@ -694,7 +703,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
694703 select {
695704 case <- req .Context ().Done ():
696705 timer .Stop ()
697- c .HTTPClient .CloseIdleConnections ()
698706 return nil , req .Context ().Err ()
699707 case <- timer .C :
700708 }
@@ -710,8 +718,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
710718 return resp , nil
711719 }
712720
713- defer c .HTTPClient .CloseIdleConnections ()
714-
715721 var err error
716722 if checkErr != nil {
717723 err = checkErr
@@ -758,6 +764,19 @@ func (c *Client) drainBody(body io.ReadCloser) {
758764 }
759765}
760766
767+ func (c * Client ) getHTTPClient () HTTPClient {
768+ if c .HTTPClient != nil {
769+ return & idleConnectionsClosingClient {
770+ httpClient : c .HTTPClient ,
771+ }
772+ }
773+ clientFactory := c .HTTPClientFactory
774+ if clientFactory == nil {
775+ clientFactory = & CleanPooledClientFactory {}
776+ }
777+ return clientFactory .New ()
778+ }
779+
761780// Get is a shortcut for doing a GET request without making a new client.
762781func Get (url string ) (* http.Response , error ) {
763782 return defaultClient .Get (url )
@@ -820,3 +839,29 @@ func (c *Client) StandardClient() *http.Client {
820839 Transport : & RoundTripper {Client : c },
821840 }
822841}
842+
843+ var (
844+ _ HTTPClientFactory = & CleanPooledClientFactory {}
845+ _ HTTPClient = & idleConnectionsClosingClient {}
846+ )
847+
848+ type CleanPooledClientFactory struct {
849+ }
850+
851+ func (f * CleanPooledClientFactory ) New () HTTPClient {
852+ return & idleConnectionsClosingClient {
853+ httpClient : cleanhttp .DefaultPooledClient (),
854+ }
855+ }
856+
857+ type idleConnectionsClosingClient struct {
858+ httpClient * http.Client
859+ }
860+
861+ func (c * idleConnectionsClosingClient ) Do (req * http.Request ) (* http.Response , error ) {
862+ return c .httpClient .Do (req )
863+ }
864+
865+ func (c * idleConnectionsClosingClient ) Done () {
866+ c .httpClient .CloseIdleConnections ()
867+ }
0 commit comments