Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for custom dial function with timeouts #1669

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 65 additions & 30 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,15 @@ type Client struct {

// Callback for establishing new connections to hosts.
//
// Default Dial is used if not set.
// Default DialTimeout is used if not set.
DialTimeout DialFuncWithTimeout

// Callback for establishing new connections to hosts.
//
// Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout.
// If you want the tcp dial process to account for request timeouts, use DialTimeout instead.
//
// If not set, DialTimeout is used.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we may consider deprecating this one if you feel like it.
I went with a less aggressive approach and just documented the difference between the two.
it may seem a bit confusing in terms of the difference between them though...

Dial DialFunc

// Attempt to connect to both ipv4 and ipv6 addresses if set to true.
Expand Down Expand Up @@ -505,6 +513,7 @@ func (c *Client) Do(req *Request, resp *Response) error {
Name: c.Name,
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader,
Dial: c.Dial,
DialTimeout: c.DialTimeout,
DialDualStack: c.DialDualStack,
IsTLS: isTLS,
TLSConfig: c.TLSConfig,
Expand Down Expand Up @@ -624,6 +633,21 @@ const DefaultMaxIdemponentCallAttempts = 5
// - foobar.com:8080
type DialFunc func(addr string) (net.Conn, error)

// DialFuncWithTimeout must establish connection to addr.
// Unlike DialFunc, it also accepts a timeout.
//
// There is no need in establishing TLS (SSL) connection for https.
// The client automatically converts connection to TLS
// if HostClient.IsTLS is set.
//
// TCP address passed to DialFuncWithTimeout always contains host and port.
// Example TCP addr values:
//
// - foobar.com:80
// - foobar.com:443
// - foobar.com:8080
type DialFuncWithTimeout func(addr string, timeout time.Duration) (net.Conn, error)

// RetryIfFunc signature of retry if function
//
// Request argument passed to RetryIfFunc, if there are any request errors.
Expand Down Expand Up @@ -656,7 +680,7 @@ type HostClient struct {
noCopy noCopy

// Comma-separated list of upstream HTTP server host addresses,
// which are passed to Dial in a round-robin manner.
// which are passed to Dial or DialTimeout in a round-robin manner.
//
// Each address may contain port if default dialer is used.
// For example,
Expand All @@ -673,16 +697,24 @@ type HostClient struct {
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool

// Callback for establishing new connection to the host.
// Callback for establishing new connections to hosts.
//
// Default Dial is used if not set.
// Default DialTimeout is used if not set.
DialTimeout DialFuncWithTimeout

// Callback for establishing new connections to hosts.
//
// Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout.
// If you want the tcp dial process to account for request timeouts, use DialTimeout instead.
//
// If not set, DialTimeout is used.
Dial DialFunc

// Attempt to connect to both ipv4 and ipv6 host addresses
// if set to true.
//
// This option is used only if default TCP dialer is used,
// i.e. if Dial is blank.
// i.e. if Dial and DialTimeout are blank.
//
// By default client connects only to ipv4 addresses,
// since unfortunately ipv6 remains broken in many networks worldwide :)
Expand Down Expand Up @@ -1827,7 +1859,8 @@ func (c *HostClient) nextAddr() string {
}

func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err error) {
// use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or dial has been set.
// use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or if
// c.DialTimeout has not been set and c.Dial has been set.
// attempt to dial all the available hosts before giving up.

c.addrsLock.Lock()
Expand All @@ -1839,16 +1872,6 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err
n = 1
}

dial := c.Dial
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this block is removed in favor of the new callDialFunc function

if dialTimeout != 0 && dial == nil {
dial = func(addr string) (net.Conn, error) {
if c.DialDualStack {
return DialDualStackTimeout(addr, dialTimeout)
}
return DialTimeout(addr, dialTimeout)
}
}

timeout := c.ReadTimeout + c.WriteTimeout
if timeout <= 0 {
timeout = DefaultDialTimeout
Expand All @@ -1857,7 +1880,7 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err
for n > 0 {
addr := c.nextAddr()
tlsConfig := c.cachedTLSConfig(addr)
conn, err = dialAddr(addr, dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
conn, err = dialAddr(addr, c.Dial, c.DialTimeout, c.DialDualStack, c.IsTLS, tlsConfig, dialTimeout, c.WriteTimeout)
if err == nil {
return conn, nil
}
Expand Down Expand Up @@ -1916,17 +1939,9 @@ func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, deadline time.T
return conn, nil
}

func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
deadline := time.Now().Add(timeout)
if dial == nil {
if dialDualStack {
dial = DialDualStack
} else {
dial = Dial
}
addr = AddMissingPort(addr, isTLS)
}
conn, err := dial(addr)
func dialAddr(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, tlsConfig *tls.Config, dialTimeout, writeTimeout time.Duration) (net.Conn, error) {
deadline := time.Now().Add(writeTimeout)
conn, err := callDialFunc(addr, dial, dialWithTimeout, dialDualStack, isTLS, dialTimeout)
if err != nil {
return nil, err
}
Expand All @@ -1939,14 +1954,34 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
_, isTLSAlready := conn.(interface{ Handshake() error })

if isTLS && !isTLSAlready {
if timeout == 0 {
if writeTimeout == 0 {
return tls.Client(conn, tlsConfig), nil
}
return tlsClientHandshake(conn, tlsConfig, deadline)
}
return conn, nil
}

func callDialFunc(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, timeout time.Duration) (net.Conn, error) {
if dialWithTimeout != nil {
return dialWithTimeout(addr, timeout)
}
if dial != nil {
return dial(addr)
}
addr = AddMissingPort(addr, isTLS)
if timeout > 0 {
if dialDualStack {
return DialDualStackTimeout(addr, timeout)
}
return DialTimeout(addr, timeout)
}
if dialDualStack {
return DialDualStack(addr)
}
return Dial(addr)
}

// AddMissingPort adds a port to a host if it is missing.
// A literal IPv6 address in hostport must be enclosed in square
// brackets, as in "[::1]:80", "[::1%lo0]:80".
Expand Down Expand Up @@ -2591,7 +2626,7 @@ func (c *pipelineConnClient) init() {

func (c *pipelineConnClient) worker() error {
tlsConfig := c.cachedTLSConfig()
conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
conn, err := dialAddr(c.Addr, c.Dial, nil, c.DialDualStack, c.IsTLS, tlsConfig, 0, c.WriteTimeout)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to not overcomplicate the commit, i left PipelineClient as is - meaning although it does implement a DoTimeout function, it does not support passing it to the dial process and thus does not expose a DialWithTimeout param. Adding it could be quite complex due to the async nature of PipelineClient, the dial process might serve several different requests so to support it we might need to further design how exactly we want to act when we have several different timeouts on a single dial call.

If we do want to support it in PipelineClient, we can either address it in this PR or postpone it to later. I will need to understand how to support it here though

if err != nil {
return err
}
Expand Down
92 changes: 92 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -3392,3 +3393,94 @@ func Test_getRedirectURL(t *testing.T) {
})
}
}

type clientDoTimeOuter interface {
DoTimeout(req *Request, resp *Response, timeout time.Duration) error
}

func TestDialTimeout(t *testing.T) {
t.Parallel()

tests := []struct {
name string
client clientDoTimeOuter
requestTimeout time.Duration
shouldFailFast bool
}{
{
name: "Client should fail after a millisecond due to request timeout",
client: &Client{
// should be ignored due to DialTimeout
Dial: func(addr string) (net.Conn, error) {
time.Sleep(time.Second)
return nil, errors.New("timeout")
},
// should be used
DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
time.Sleep(timeout)
return nil, errors.New("timeout")
},
},
requestTimeout: time.Millisecond,
shouldFailFast: true,
},
{
name: "Client should fail after a second due to no DialTimeout set",
client: &Client{
Dial: func(addr string) (net.Conn, error) {
time.Sleep(time.Second)
return nil, errors.New("timeout")
},
},
requestTimeout: time.Millisecond,
shouldFailFast: false,
},
{
name: "HostClient should fail after a millisecond due to request timeout",
client: &HostClient{
// should be ignored due to DialTimeout
Dial: func(addr string) (net.Conn, error) {
time.Sleep(time.Second)
return nil, errors.New("timeout")
},
// should be used
DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
time.Sleep(timeout)
return nil, errors.New("timeout")
},
},
requestTimeout: time.Millisecond,
shouldFailFast: true,
},
{
name: "HostClient should fail after a second due to no DialTimeout set",
client: &HostClient{
Dial: func(addr string) (net.Conn, error) {
time.Sleep(time.Second)
return nil, errors.New("timeout")
},
},
requestTimeout: time.Millisecond,
shouldFailFast: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
start := time.Now()
err := tt.client.DoTimeout(&Request{}, &Response{}, tt.requestTimeout)
if err == nil {
t.Fatal("expected error (timeout)")
}
if tt.shouldFailFast {
if time.Since(start) > time.Second {
t.Fatal("expected timeout after a millisecond")
}
} else {
if time.Since(start) < time.Second {
t.Fatal("expected timeout after a second")
}
}
})
}
}
8 changes: 4 additions & 4 deletions tcpdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func Dial(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
Expand Down Expand Up @@ -102,7 +102,7 @@ func DialDualStack(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
Expand Down Expand Up @@ -199,7 +199,7 @@ func (d *TCPDialer) Dial(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
Expand Down Expand Up @@ -253,7 +253,7 @@ func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
// to Client.Dial or HostClient.Dial.
// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
Expand Down
Loading