Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion go/vt/topo/zk2topo/zk_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@ limitations under the License.
package zk2topo

import (
"crypto/tls"
"crypto/x509"
"flag"
"fmt"
"io/ioutil"
"math/rand"
"net"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -50,6 +54,10 @@ var (
maxConcurrency = flag.Int("topo_zk_max_concurrency", 64, "maximum number of pending requests to send to a Zookeeper server.")

baseTimeout = flag.Duration("topo_zk_base_timeout", 30*time.Second, "zk base timeout (see zk.Connect)")

certPath = flag.String("topo_zk_tls_cert", "", "the cert to use to connect to the zk topo server, requires topo_zk_tls_key, enables TLS")
keyPath = flag.String("topo_zk_tls_key", "", "the key to use to connect to the zk topo server, enables TLS")
caPath = flag.String("topo_zk_tls_ca", "", "the server ca to use to validate servers when connecting to the zk topo server")
)

// Time returns a time.Time from a ZK int64 milliseconds since Epoch time.
Expand Down Expand Up @@ -304,8 +312,46 @@ func dialZk(ctx context.Context, addr string) (*zk.Conn, <-chan zk.Event, error)
return nil, nil, err
}

options := zk.WithDialer(net.DialTimeout)
// If TLS is enabled use a TLS enabled dialer option
if *certPath != "" && *keyPath != "" {
if strings.Contains(addr, ",") {
log.Fatalf("This TLS zk code requires that the all the zk servers validate to a single server name.")
}

serverName := strings.Split(addr, ":")[0]

log.Infof("Using TLS ZK, connecting to %v server name %v", addr, serverName)
cert, err := tls.LoadX509KeyPair(*certPath, *keyPath)
if err != nil {
log.Fatalf("Unable to load cert %v and key %v, err %v", *certPath, *keyPath, err)
}

clientCACert, err := ioutil.ReadFile(*caPath)
if err != nil {
log.Fatalf("Unable to open ca cert %v, err %v", *caPath, err)
}

clientCertPool := x509.NewCertPool()
clientCertPool.AppendCertsFromPEM(clientCACert)

tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
RootCAs: clientCertPool,
ServerName: serverName,
}

tlsConfig.BuildNameToCertificate()

options = zk.WithDialer(func(network, address string, timeout time.Duration) (net.Conn, error) {
d := net.Dialer{Timeout: timeout}

return tls.DialWithDialer(&d, network, address, tlsConfig)
})
}

// zk.Connect automatically shuffles the servers
zconn, session, err := zk.Connect(servers, *baseTimeout)
zconn, session, err := zk.Connect(servers, *baseTimeout, options)
if err != nil {
return nil, nil, err
}
Expand Down