diff --git a/v8/client/client.go b/v8/client/client.go index 074e3f12..335132c6 100644 --- a/v8/client/client.go +++ b/v8/client/client.go @@ -19,6 +19,7 @@ import ( "github.com/jcmturner/gokrb5/v8/krberror" "github.com/jcmturner/gokrb5/v8/messages" "github.com/jcmturner/gokrb5/v8/types" + "golang.org/x/net/proxy" ) // Client side configuration and state. @@ -28,6 +29,7 @@ type Client struct { settings *Settings sessions *sessions cache *Cache + tcpDialer proxy.Dialer } // NewWithPassword creates a new client from a password credential. @@ -327,3 +329,7 @@ func (cl *Client) Print(w io.Writer) { k, _ := cl.Credentials.Keytab().JSON() fmt.Fprintf(w, "Keytab:\n%s\n", k) } + +func (cl *Client) SetDialerForTCP(dialer proxy.Dialer) { + cl.tcpDialer = dialer +} diff --git a/v8/client/network.go b/v8/client/network.go index 8a383c3f..7da99e87 100644 --- a/v8/client/network.go +++ b/v8/client/network.go @@ -10,6 +10,7 @@ import ( "github.com/jcmturner/gokrb5/v8/iana/errorcode" "github.com/jcmturner/gokrb5/v8/messages" + "golang.org/x/net/proxy" ) // SendToKDC performs network actions to send data to the KDC. @@ -132,7 +133,7 @@ func (cl *Client) sendKDCTCP(realm string, b []byte) ([]byte, error) { if err != nil { return r, err } - r, err = dialSendTCP(kdcs, b) + r, err = dialSendTCP(kdcs, b, cl.tcpDialer) if err != nil { return r, err } @@ -140,10 +141,18 @@ func (cl *Client) sendKDCTCP(realm string, b []byte) ([]byte, error) { } // dialKDCTCP establishes a TCP connection to a KDC. -func dialSendTCP(kdcs map[int]string, b []byte) ([]byte, error) { +func dialSendTCP(kdcs map[int]string, b []byte, dialer proxy.Dialer) ([]byte, error) { var errs []string for i := 1; i <= len(kdcs); i++ { - conn, err := net.DialTimeout("tcp", kdcs[i], 5*time.Second) + var conn net.Conn + var err error + + if dialer != nil { + conn, err = dialer.Dial("tcp", kdcs[i]) + } else { + conn, err = net.DialTimeout("tcp", kdcs[i], 5*time.Second) + } + if err != nil { errs = append(errs, fmt.Sprintf("error establishing connection to %s: %v", kdcs[i], err)) continue @@ -155,7 +164,7 @@ func dialSendTCP(kdcs map[int]string, b []byte) ([]byte, error) { // conn is guaranteed to be a TCPConn rb, err := sendTCP(conn.(*net.TCPConn), b) if err != nil { - errs = append(errs, fmt.Sprintf("error sneding to %s: %v", kdcs[i], err)) + errs = append(errs, fmt.Sprintf("error sending to %s: %v", kdcs[i], err)) continue } return rb, nil diff --git a/v8/client/passwd.go b/v8/client/passwd.go index fe20559c..22059f54 100644 --- a/v8/client/passwd.go +++ b/v8/client/passwd.go @@ -65,7 +65,7 @@ func (cl *Client) sendToKPasswd(msg kadmin.Request) (r kadmin.Reply, err error) return } } else { - rb, err = dialSendTCP(kps, b) + rb, err = dialSendTCP(kps, b, cl.tcpDialer) if err != nil { return }