-
Notifications
You must be signed in to change notification settings - Fork 1.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add support for custom dial function with timeouts #1669
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -185,7 +185,15 @@ type Client struct { | |
|
||
// Callback for establishing new connections to hosts. | ||
// | ||
// Default Dial is used if not set. | ||
// Default DialTimeout is used if not set. | ||
DialTimeout DialFuncWithTimeout | ||
|
||
// Callback for establishing new connections to hosts. | ||
// | ||
// Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout. | ||
// If you want the tcp dial process to account for request timeouts, use DialTimeout instead. | ||
// | ||
// If not set, DialTimeout is used. | ||
Dial DialFunc | ||
|
||
// Attempt to connect to both ipv4 and ipv6 addresses if set to true. | ||
|
@@ -505,6 +513,7 @@ func (c *Client) Do(req *Request, resp *Response) error { | |
Name: c.Name, | ||
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader, | ||
Dial: c.Dial, | ||
DialTimeout: c.DialTimeout, | ||
DialDualStack: c.DialDualStack, | ||
IsTLS: isTLS, | ||
TLSConfig: c.TLSConfig, | ||
|
@@ -624,6 +633,21 @@ const DefaultMaxIdemponentCallAttempts = 5 | |
// - foobar.com:8080 | ||
type DialFunc func(addr string) (net.Conn, error) | ||
|
||
// DialFuncWithTimeout must establish connection to addr. | ||
// Unlike DialFunc, it also accepts a timeout. | ||
// | ||
// There is no need in establishing TLS (SSL) connection for https. | ||
// The client automatically converts connection to TLS | ||
// if HostClient.IsTLS is set. | ||
// | ||
// TCP address passed to DialFuncWithTimeout always contains host and port. | ||
// Example TCP addr values: | ||
// | ||
// - foobar.com:80 | ||
// - foobar.com:443 | ||
// - foobar.com:8080 | ||
type DialFuncWithTimeout func(addr string, timeout time.Duration) (net.Conn, error) | ||
|
||
// RetryIfFunc signature of retry if function | ||
// | ||
// Request argument passed to RetryIfFunc, if there are any request errors. | ||
|
@@ -656,7 +680,7 @@ type HostClient struct { | |
noCopy noCopy | ||
|
||
// Comma-separated list of upstream HTTP server host addresses, | ||
// which are passed to Dial in a round-robin manner. | ||
// which are passed to Dial or DialTimeout in a round-robin manner. | ||
// | ||
// Each address may contain port if default dialer is used. | ||
// For example, | ||
|
@@ -673,16 +697,24 @@ type HostClient struct { | |
// User-Agent header to be excluded from the Request. | ||
NoDefaultUserAgentHeader bool | ||
|
||
// Callback for establishing new connection to the host. | ||
// Callback for establishing new connections to hosts. | ||
// | ||
// Default Dial is used if not set. | ||
// Default DialTimeout is used if not set. | ||
DialTimeout DialFuncWithTimeout | ||
|
||
// Callback for establishing new connections to hosts. | ||
// | ||
// Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout. | ||
// If you want the tcp dial process to account for request timeouts, use DialTimeout instead. | ||
// | ||
// If not set, DialTimeout is used. | ||
Dial DialFunc | ||
|
||
// Attempt to connect to both ipv4 and ipv6 host addresses | ||
// if set to true. | ||
// | ||
// This option is used only if default TCP dialer is used, | ||
// i.e. if Dial is blank. | ||
// i.e. if Dial and DialTimeout are blank. | ||
// | ||
// By default client connects only to ipv4 addresses, | ||
// since unfortunately ipv6 remains broken in many networks worldwide :) | ||
|
@@ -1827,7 +1859,8 @@ func (c *HostClient) nextAddr() string { | |
} | ||
|
||
func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err error) { | ||
// use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or dial has been set. | ||
// use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or if | ||
// c.DialTimeout has not been set and c.Dial has been set. | ||
// attempt to dial all the available hosts before giving up. | ||
|
||
c.addrsLock.Lock() | ||
|
@@ -1839,16 +1872,6 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err | |
n = 1 | ||
} | ||
|
||
dial := c.Dial | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this block is removed in favor of the new callDialFunc function |
||
if dialTimeout != 0 && dial == nil { | ||
dial = func(addr string) (net.Conn, error) { | ||
if c.DialDualStack { | ||
return DialDualStackTimeout(addr, dialTimeout) | ||
} | ||
return DialTimeout(addr, dialTimeout) | ||
} | ||
} | ||
|
||
timeout := c.ReadTimeout + c.WriteTimeout | ||
if timeout <= 0 { | ||
timeout = DefaultDialTimeout | ||
|
@@ -1857,7 +1880,7 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err | |
for n > 0 { | ||
addr := c.nextAddr() | ||
tlsConfig := c.cachedTLSConfig(addr) | ||
conn, err = dialAddr(addr, dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout) | ||
conn, err = dialAddr(addr, c.Dial, c.DialTimeout, c.DialDualStack, c.IsTLS, tlsConfig, dialTimeout, c.WriteTimeout) | ||
if err == nil { | ||
return conn, nil | ||
} | ||
|
@@ -1916,17 +1939,9 @@ func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, deadline time.T | |
return conn, nil | ||
} | ||
|
||
func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) { | ||
deadline := time.Now().Add(timeout) | ||
if dial == nil { | ||
if dialDualStack { | ||
dial = DialDualStack | ||
} else { | ||
dial = Dial | ||
} | ||
addr = AddMissingPort(addr, isTLS) | ||
} | ||
conn, err := dial(addr) | ||
func dialAddr(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, tlsConfig *tls.Config, dialTimeout, writeTimeout time.Duration) (net.Conn, error) { | ||
deadline := time.Now().Add(writeTimeout) | ||
conn, err := callDialFunc(addr, dial, dialWithTimeout, dialDualStack, isTLS, dialTimeout) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
@@ -1939,14 +1954,34 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig * | |
_, isTLSAlready := conn.(interface{ Handshake() error }) | ||
|
||
if isTLS && !isTLSAlready { | ||
if timeout == 0 { | ||
if writeTimeout == 0 { | ||
return tls.Client(conn, tlsConfig), nil | ||
} | ||
return tlsClientHandshake(conn, tlsConfig, deadline) | ||
} | ||
return conn, nil | ||
} | ||
|
||
func callDialFunc(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, timeout time.Duration) (net.Conn, error) { | ||
if dialWithTimeout != nil { | ||
return dialWithTimeout(addr, timeout) | ||
} | ||
if dial != nil { | ||
return dial(addr) | ||
} | ||
addr = AddMissingPort(addr, isTLS) | ||
if timeout > 0 { | ||
if dialDualStack { | ||
return DialDualStackTimeout(addr, timeout) | ||
} | ||
return DialTimeout(addr, timeout) | ||
} | ||
if dialDualStack { | ||
return DialDualStack(addr) | ||
} | ||
return Dial(addr) | ||
} | ||
|
||
// AddMissingPort adds a port to a host if it is missing. | ||
// A literal IPv6 address in hostport must be enclosed in square | ||
// brackets, as in "[::1]:80", "[::1%lo0]:80". | ||
|
@@ -2591,7 +2626,7 @@ func (c *pipelineConnClient) init() { | |
|
||
func (c *pipelineConnClient) worker() error { | ||
tlsConfig := c.cachedTLSConfig() | ||
conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout) | ||
conn, err := dialAddr(c.Addr, c.Dial, nil, c.DialDualStack, c.IsTLS, tlsConfig, 0, c.WriteTimeout) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. to not overcomplicate the commit, i left PipelineClient as is - meaning although it does implement a DoTimeout function, it does not support passing it to the dial process and thus does not expose a DialWithTimeout param. Adding it could be quite complex due to the async nature of PipelineClient, the dial process might serve several different requests so to support it we might need to further design how exactly we want to act when we have several different timeouts on a single dial call. If we do want to support it in PipelineClient, we can either address it in this PR or postpone it to later. I will need to understand how to support it here though |
||
if err != nil { | ||
return err | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we may consider deprecating this one if you feel like it.
I went with a less aggressive approach and just documented the difference between the two.
it may seem a bit confusing in terms of the difference between them though...