From f6feddbbf9386441c5dd0c06808aae39d7352d7a Mon Sep 17 00:00:00 2001 From: Michael Pawliszyn Date: Mon, 14 May 2018 14:41:00 -0400 Subject: [PATCH] Adds the option for TLS enabled zk topo server connections. Signed-off-by: Michael Pawliszyn --- go/vt/topo/zk2topo/zk_conn.go | 48 ++++++++++++++++++++++++++++++++++- 1 file changed, 47 insertions(+), 1 deletion(-) diff --git a/go/vt/topo/zk2topo/zk_conn.go b/go/vt/topo/zk2topo/zk_conn.go index 507f5c02569..661068c7854 100644 --- a/go/vt/topo/zk2topo/zk_conn.go +++ b/go/vt/topo/zk2topo/zk_conn.go @@ -17,9 +17,13 @@ limitations under the License. package zk2topo import ( + "crypto/tls" + "crypto/x509" "flag" "fmt" + "io/ioutil" "math/rand" + "net" "strings" "sync" "time" @@ -50,6 +54,10 @@ var ( maxConcurrency = flag.Int("topo_zk_max_concurrency", 64, "maximum number of pending requests to send to a Zookeeper server.") baseTimeout = flag.Duration("topo_zk_base_timeout", 30*time.Second, "zk base timeout (see zk.Connect)") + + certPath = flag.String("topo_zk_tls_cert", "", "the cert to use to connect to the zk topo server, requires topo_zk_tls_key, enables TLS") + keyPath = flag.String("topo_zk_tls_key", "", "the key to use to connect to the zk topo server, enables TLS") + caPath = flag.String("topo_zk_tls_ca", "", "the server ca to use to validate servers when connecting to the zk topo server") ) // Time returns a time.Time from a ZK int64 milliseconds since Epoch time. @@ -304,8 +312,46 @@ func dialZk(ctx context.Context, addr string) (*zk.Conn, <-chan zk.Event, error) return nil, nil, err } + options := zk.WithDialer(net.DialTimeout) + // If TLS is enabled use a TLS enabled dialer option + if *certPath != "" && *keyPath != "" { + if strings.Contains(addr, ",") { + log.Fatalf("This TLS zk code requires that the all the zk servers validate to a single server name.") + } + + serverName := strings.Split(addr, ":")[0] + + log.Infof("Using TLS ZK, connecting to %v server name %v", addr, serverName) + cert, err := tls.LoadX509KeyPair(*certPath, *keyPath) + if err != nil { + log.Fatalf("Unable to load cert %v and key %v, err %v", *certPath, *keyPath, err) + } + + clientCACert, err := ioutil.ReadFile(*caPath) + if err != nil { + log.Fatalf("Unable to open ca cert %v, err %v", *caPath, err) + } + + clientCertPool := x509.NewCertPool() + clientCertPool.AppendCertsFromPEM(clientCACert) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: clientCertPool, + ServerName: serverName, + } + + tlsConfig.BuildNameToCertificate() + + options = zk.WithDialer(func(network, address string, timeout time.Duration) (net.Conn, error) { + d := net.Dialer{Timeout: timeout} + + return tls.DialWithDialer(&d, network, address, tlsConfig) + }) + } + // zk.Connect automatically shuffles the servers - zconn, session, err := zk.Connect(servers, *baseTimeout) + zconn, session, err := zk.Connect(servers, *baseTimeout, options) if err != nil { return nil, nil, err }