diff --git a/Procfile b/Procfile index 6ead5c3f79429..65d868a6be75a 100644 --- a/Procfile +++ b/Procfile @@ -1,3 +1,3 @@ controller: go run ./cmd/argocd-application-controller/main.go --app-resync 10 -api-server: go run ./cmd/argocd-server/main.go +api-server: go run ./cmd/argocd-server/main.go --insecure repo-server: go run ./cmd/argocd-repo-server/main.go diff --git a/cmd/argocd-server/commands/root.go b/cmd/argocd-server/commands/root.go index f181ec764a4b5..df0eea8eb178d 100644 --- a/cmd/argocd-server/commands/root.go +++ b/cmd/argocd-server/commands/root.go @@ -15,6 +15,7 @@ import ( // NewCommand returns a new instance of an argocd command func NewCommand() *cobra.Command { var ( + insecure bool logLevel string clientConfig clientcmd.ClientConfig staticAssetsDir string @@ -39,12 +40,21 @@ func NewCommand() *cobra.Command { appclientset := appclientset.NewForConfigOrDie(config) repoclientset := reposerver.NewRepositoryServerClientset(repoServerAddress) - argocd := server.NewServer(kubeclientset, appclientset, repoclientset, namespace, staticAssetsDir) + argoCDOpts := server.ArgoCDServerOpts{ + Insecure: insecure, + Namespace: namespace, + StaticAssetsDir: staticAssetsDir, + KubeClientset: kubeclientset, + AppClientset: appclientset, + RepoClientset: repoclientset, + } + argocd := server.NewServer(argoCDOpts) argocd.Run() }, } clientConfig = cli.AddKubectlFlagsToCmd(command) + command.Flags().BoolVar(&insecure, "insecure", false, "Run server without TLS") command.Flags().StringVar(&staticAssetsDir, "staticassets", "", "Static assets directory path") command.Flags().StringVar(&logLevel, "loglevel", "info", "Set the logging level. One of: debug|info|warn|error") command.Flags().StringVar(&repoServerAddress, "repo-server", "localhost:8081", "Repo server address.") diff --git a/cmd/argocd/commands/root.go b/cmd/argocd/commands/root.go index 5fe00a9e72b0d..1410bc45b1872 100644 --- a/cmd/argocd/commands/root.go +++ b/cmd/argocd/commands/root.go @@ -30,7 +30,8 @@ func NewCommand() *cobra.Command { command.AddCommand(NewUninstallCommand()) command.PersistentFlags().StringVar(&clientOpts.ServerAddr, "server", "", "ArgoCD server address") - command.PersistentFlags().BoolVar(&clientOpts.Insecure, "insecure", true, "Disable transport security for the client connection") + command.PersistentFlags().BoolVar(&clientOpts.Insecure, "insecure", false, "Disable transport security for the client connection, including host verification") + command.PersistentFlags().StringVar(&clientOpts.CertFile, "server-crt", "", "Server certificate file") return command } diff --git a/install/install.go b/install/install.go index 1f0ac2861bedb..bdf64661f6c59 100644 --- a/install/install.go +++ b/install/install.go @@ -12,6 +12,7 @@ import ( "github.com/argoproj/argo-cd/util/kube" "github.com/argoproj/argo-cd/util/password" "github.com/argoproj/argo-cd/util/session" + tlsutil "github.com/argoproj/argo-cd/util/tls" "github.com/ghodss/yaml" "github.com/gobuffalo/packr" log "github.com/sirupsen/logrus" @@ -132,27 +133,40 @@ func (i *Installer) InstallSettings() { errors.CheckError(err) configManager := config.NewConfigManager(kubeclientset, i.Namespace) _, err = configManager.GetSettings() - if err != nil { - if !apierr.IsNotFound(err) { - log.Fatal(err) - } - // configmap/secret not yet created - signature, err := session.MakeSignature(32) - errors.CheckError(err) - passwordRaw := readAndConfirmPassword() - hashedPassword, err := password.HashPassword(passwordRaw) - errors.CheckError(err) - newSettings := config.ArgoCDSettings{ - ServerSignature: signature, - LocalUsers: map[string]string{ - common.ArgoCDAdminUsername: hashedPassword, - }, - } - err = configManager.SaveSettings(&newSettings) - errors.CheckError(err) - } else { + if err == nil { log.Infof("Settings already exists. Skipping creation") + return } + if !apierr.IsNotFound(err) { + log.Fatal(err) + } + // configmap/secret not yet created + var newSettings config.ArgoCDSettings + + // set JWT signature + signature, err := session.MakeSignature(32) + errors.CheckError(err) + newSettings.ServerSignature = signature + + // generate admin password + passwordRaw := readAndConfirmPassword() + hashedPassword, err := password.HashPassword(passwordRaw) + errors.CheckError(err) + newSettings.LocalUsers = map[string]string{ + common.ArgoCDAdminUsername: hashedPassword, + } + + // generate TLS cert + certOpts := tlsutil.CertOptions{ + Host: "argocd", + Organization: "Argo CD", + } + cert, err := tlsutil.GenerateX509KeyPair(certOpts) + errors.CheckError(err) + newSettings.Certificate = cert + + err = configManager.SaveSettings(&newSettings) + errors.CheckError(err) } func readAndConfirmPassword() string { diff --git a/pkg/apiclient/apiclient.go b/pkg/apiclient/apiclient.go index 12ff1c3ca2b58..35cee4d2dc6e2 100644 --- a/pkg/apiclient/apiclient.go +++ b/pkg/apiclient/apiclient.go @@ -1,14 +1,21 @@ package apiclient import ( + "context" + "crypto/tls" + "crypto/x509" "errors" + "fmt" + "io/ioutil" "os" "github.com/argoproj/argo-cd/server/application" "github.com/argoproj/argo-cd/server/cluster" "github.com/argoproj/argo-cd/server/repository" + grpc_util "github.com/argoproj/argo-cd/util/grpc" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" ) const ( @@ -29,6 +36,7 @@ type ServerClient interface { type ClientOptions struct { ServerAddr string Insecure bool + CertFile string } type client struct { @@ -57,15 +65,32 @@ func NewClientOrDie(opts *ClientOptions) ServerClient { } func (c *client) NewConn() (*grpc.ClientConn, error) { - var dialOpts []grpc.DialOption - if c.Insecure { - dialOpts = append(dialOpts, grpc.WithInsecure()) + var creds credentials.TransportCredentials + if c.CertFile != "" { + b, err := ioutil.ReadFile(c.CertFile) + if err != nil { + return nil, err + } + cp := x509.NewCertPool() + if !cp.AppendCertsFromPEM(b) { + return nil, fmt.Errorf("credentials: failed to append certificates") + } + tlsConfig := tls.Config{ + RootCAs: cp, + } + if c.Insecure { + tlsConfig.InsecureSkipVerify = true + } + creds = credentials.NewTLS(&tlsConfig) } else { - return nil, errors.New("secure authentication unsupported") - } // else if opts.Credentials != nil { - // dialOpts = append(dialOpts, grpc.WithTransportCredentials(opts.Credentials)) - //} - return grpc.Dial(c.ServerAddr, dialOpts...) + if c.Insecure { + tlsConfig := tls.Config{ + InsecureSkipVerify: true, + } + creds = credentials.NewTLS(&tlsConfig) + } + } + return grpc_util.BlockingDial(context.Background(), "tcp", c.ServerAddr, creds) } func (c *client) NewRepoClient() (*grpc.ClientConn, repository.RepositoryServiceClient, error) { diff --git a/server/server.go b/server/server.go index a90aa0cc972f7..ee7f90b92c676 100644 --- a/server/server.go +++ b/server/server.go @@ -2,12 +2,15 @@ package server import ( "context" + "crypto/tls" + "crypto/x509" "fmt" "net" "net/http" "strings" argocd "github.com/argoproj/argo-cd" + "github.com/argoproj/argo-cd/errors" appclientset "github.com/argoproj/argo-cd/pkg/client/clientset/versioned" "github.com/argoproj/argo-cd/reposerver" "github.com/argoproj/argo-cd/server/application" @@ -18,12 +21,14 @@ import ( "github.com/argoproj/argo-cd/util/config" grpc_util "github.com/argoproj/argo-cd/util/grpc" jsonutil "github.com/argoproj/argo-cd/util/json" + tlsutil "github.com/argoproj/argo-cd/util/tls" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus" "github.com/grpc-ecosystem/grpc-gateway/runtime" log "github.com/sirupsen/logrus" "github.com/soheilhy/cmux" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "k8s.io/client-go/kubernetes" ) @@ -37,31 +42,32 @@ var ( // ArgoCDServer is the API server for ArgoCD type ArgoCDServer struct { - ns string - staticAssetsDir string - kubeclientset kubernetes.Interface - appclientset appclientset.Interface - repoclientset reposerver.Clientset - settings config.ArgoCDSettings - log *log.Entry + ArgoCDServerOpts + + settings config.ArgoCDSettings + log *log.Entry +} + +type ArgoCDServerOpts struct { + Insecure bool + Namespace string + StaticAssetsDir string + KubeClientset kubernetes.Interface + AppClientset appclientset.Interface + RepoClientset reposerver.Clientset } // NewServer returns a new instance of the ArgoCD API server -func NewServer( - kubeclientset kubernetes.Interface, appclientset appclientset.Interface, repoclientset reposerver.Clientset, namespace, staticAssetsDir string) *ArgoCDServer { - configManager := config.NewConfigManager(kubeclientset, namespace) +func NewServer(opts ArgoCDServerOpts) *ArgoCDServer { + configManager := config.NewConfigManager(opts.KubeClientset, opts.Namespace) settings, err := configManager.GetSettings() if err != nil { log.Fatal(err) } return &ArgoCDServer{ - ns: namespace, - kubeclientset: kubeclientset, - appclientset: appclientset, - repoclientset: repoclientset, - log: log.NewEntry(log.New()), - staticAssetsDir: staticAssetsDir, - settings: *settings, + ArgoCDServerOpts: opts, + log: log.NewEntry(log.New()), + settings: *settings, } } @@ -74,39 +80,121 @@ func (a *ArgoCDServer) Run() { ctx, cancel := context.WithCancel(ctx) defer cancel() - conn, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) - if err != nil { - panic(err) + grpcS := a.newGRPCServer() + var httpS *http.Server + var httpsS *http.Server + if a.useTLS() { + httpS = newRedirectServer() + httpsS = a.newHTTPServer(ctx) + } else { + httpS = a.newHTTPServer(ctx) } // Cmux is used to support servicing gRPC and HTTP1.1+JSON on the same port - m := cmux.New(conn) - grpcL := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) - httpL := m.Match(cmux.HTTP1Fast()) - - // gRPC Server - grpcS := grpc.NewServer( - grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( - grpc_logrus.StreamServerInterceptor(a.log), - grpc_util.PanicLoggerStreamServerInterceptor(a.log), - )), - grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( - grpc_logrus.UnaryServerInterceptor(a.log), - grpc_util.PanicLoggerUnaryServerInterceptor(a.log), - )), - ) + conn, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) + errors.CheckError(err) + + tcpm := cmux.New(conn) + var tlsm cmux.CMux + var grpcL net.Listener + var httpL net.Listener + var httpsL net.Listener + if !a.useTLS() { + httpL = tcpm.Match(cmux.HTTP1Fast()) + grpcL = tcpm.Match(cmux.HTTP2HeaderField("content-type", "application/grpc")) + } else { + // We first match on HTTP 1.1 methods. + httpL = tcpm.Match(cmux.HTTP1Fast()) + + // If not matched, we assume that its TLS. + tlsl := tcpm.Match(cmux.Any()) + tlsConfig := tls.Config{ + Certificates: []tls.Certificate{*a.settings.Certificate}, + } + tlsl = tls.NewListener(tlsl, &tlsConfig) + + // Now, we build another mux recursively to match HTTPS and GoRPC. + tlsm = cmux.New(tlsl) + httpsL = tlsm.Match(cmux.HTTP1Fast()) + grpcL = tlsm.Match(cmux.Any()) + } + + // Start the muxed listeners for our servers + log.Infof("argocd %s serving on port %d (tls: %v, namespace: %s)", argocd.GetVersion(), port, a.useTLS(), a.Namespace) + go func() { errors.CheckError(grpcS.Serve(grpcL)) }() + go func() { errors.CheckError(httpS.Serve(httpL)) }() + if a.useTLS() { + go func() { errors.CheckError(httpsS.Serve(httpsL)) }() + go func() { errors.CheckError(tlsm.Serve()) }() + } + err = tcpm.Serve() + errors.CheckError(err) +} + +func (a *ArgoCDServer) useTLS() bool { + if a.Insecure || a.settings.Certificate == nil { + return false + } + return true +} + +func (a *ArgoCDServer) newGRPCServer() *grpc.Server { + var sOpts []grpc.ServerOption + // NOTE: notice we do not configure the gRPC server here with TLS (e.g. grpc.Creds(creds)) + // This is because TLS handshaking occurs in cmux handling + sOpts = append(sOpts, grpc.StreamInterceptor(grpc_middleware.ChainStreamServer( + grpc_logrus.StreamServerInterceptor(a.log), + grpc_util.PanicLoggerStreamServerInterceptor(a.log), + ))) + sOpts = append(sOpts, grpc.UnaryInterceptor(grpc_middleware.ChainUnaryServer( + grpc_logrus.UnaryServerInterceptor(a.log), + grpc_util.PanicLoggerUnaryServerInterceptor(a.log), + ))) + + grpcS := grpc.NewServer(sOpts...) + clusterService := cluster.NewServer(a.Namespace, a.KubeClientset, a.AppClientset) + repoService := repository.NewServer(a.Namespace, a.KubeClientset, a.AppClientset) + sessionService := session.NewServer(a.Namespace, a.KubeClientset, a.AppClientset, a.settings) + applicationService := application.NewServer(a.Namespace, a.KubeClientset, a.AppClientset, a.RepoClientset, repoService, clusterService) version.RegisterVersionServiceServer(grpcS, &version.Server{}) - clusterService := cluster.NewServer(a.ns, a.kubeclientset, a.appclientset) - repoService := repository.NewServer(a.ns, a.kubeclientset, a.appclientset) - sessionService := session.NewServer(a.ns, a.kubeclientset, a.appclientset, a.settings) cluster.RegisterClusterServiceServer(grpcS, clusterService) - application.RegisterApplicationServiceServer(grpcS, application.NewServer(a.ns, a.kubeclientset, a.appclientset, a.repoclientset, repoService, clusterService)) + application.RegisterApplicationServiceServer(grpcS, applicationService) repository.RegisterRepositoryServiceServer(grpcS, repoService) session.RegisterSessionServiceServer(grpcS, sessionService) + return grpcS +} + +// newHTTPServer returns the HTTP server to serve HTTP/HTTPS requests. This is implemented +// using grpc-gateway as a proxy to the gRPC server. +func (a *ArgoCDServer) newHTTPServer(ctx context.Context) *http.Server { + mux := http.NewServeMux() + httpS := http.Server{ + Addr: endpoint, + Handler: mux, + } + var dOpts []grpc.DialOption + if a.useTLS() { + // The following sets up the dial Options for grpc-gateway to talk to gRPC server over TLS. + // grpc-gateway is just translating HTTP/HTTPS requests as gRPC requests over localhost, + // so we need to supply the same certificates to establish the connections that a normal, + // external gRPC client would need. + certPool := x509.NewCertPool() + pemCertBytes, _ := tlsutil.EncodeX509KeyPair(*a.settings.Certificate) + ok := certPool.AppendCertsFromPEM(pemCertBytes) + if !ok { + panic("bad certs") + } + dCreds := credentials.NewTLS(&tls.Config{ + RootCAs: certPool, + InsecureSkipVerify: true, + }) + dOpts = append(dOpts, grpc.WithTransportCredentials(dCreds)) + } else { + dOpts = append(dOpts, grpc.WithInsecure()) + } // HTTP 1.1+JSON Server // grpc-ecosystem/grpc-gateway is used to proxy HTTP requests to the corresponding gRPC call - mux := http.NewServeMux() // NOTE: if a marshaller option is not supplied, grpc-gateway will default to the jsonpb from // golang/protobuf. Which does not support types such as time.Time. gogo/protobuf does support // time.Time, but does not support custom UnmarshalJSON() and MarshalJSON() methods. Therefore @@ -114,45 +202,46 @@ func (a *ArgoCDServer) Run() { gwMuxOpts := runtime.WithMarshalerOption(runtime.MIMEWildcard, new(jsonutil.JSONMarshaler)) gwmux := runtime.NewServeMux(gwMuxOpts) mux.Handle("/api/", gwmux) - dOpts := []grpc.DialOption{grpc.WithInsecure()} mustRegisterGWHandler(version.RegisterVersionServiceHandlerFromEndpoint, ctx, gwmux, endpoint, dOpts) mustRegisterGWHandler(cluster.RegisterClusterServiceHandlerFromEndpoint, ctx, gwmux, endpoint, dOpts) mustRegisterGWHandler(application.RegisterApplicationServiceHandlerFromEndpoint, ctx, gwmux, endpoint, dOpts) mustRegisterGWHandler(repository.RegisterRepositoryServiceHandlerFromEndpoint, ctx, gwmux, endpoint, dOpts) mustRegisterGWHandler(session.RegisterSessionServiceHandlerFromEndpoint, ctx, gwmux, endpoint, dOpts) - if a.staticAssetsDir != "" { + if a.StaticAssetsDir != "" { mux.HandleFunc("/", func(writer http.ResponseWriter, request *http.Request) { - acceptHtml := false + acceptHTML := false for _, acceptType := range strings.Split(request.Header.Get("Accept"), ",") { if acceptType == "text/html" || acceptType == "html" { - acceptHtml = true + acceptHTML = true break } } fileRequest := request.URL.Path != "/index.html" && strings.Contains(request.URL.Path, ".") // serve index.html for non file requests to support HTML5 History API - if acceptHtml && !fileRequest && (request.Method == "GET" || request.Method == "HEAD") { - http.ServeFile(writer, request, a.staticAssetsDir+"/index.html") + if acceptHTML && !fileRequest && (request.Method == "GET" || request.Method == "HEAD") { + http.ServeFile(writer, request, a.StaticAssetsDir+"/index.html") } else { - http.ServeFile(writer, request, a.staticAssetsDir+request.URL.Path) + http.ServeFile(writer, request, a.StaticAssetsDir+request.URL.Path) } }) } + return &httpS +} - httpS := &http.Server{ - Addr: endpoint, - Handler: mux, - } - - // Start the muxed listeners for our servers - log.Infof("argocd %s serving on port %d (namespace: %s)", argocd.GetVersion(), port, a.ns) - go func() { _ = grpcS.Serve(grpcL) }() - go func() { _ = httpS.Serve(httpL) }() - err = m.Serve() - if err != nil { - panic(err) +// newRedirectServer returns an HTTP server which does a 307 redirect to the HTTPS server +func newRedirectServer() *http.Server { + return &http.Server{ + Addr: endpoint, + Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + target := "https://" + req.Host + req.URL.Path + if len(req.URL.RawQuery) > 0 { + target += "?" + req.URL.RawQuery + } + log.Printf("redirect to: %s", target) + http.Redirect(w, req, target, http.StatusTemporaryRedirect) + }), } } diff --git a/util/config/configmanager.go b/util/config/configmanager.go index b0dcb2875ed04..dc5399a4aedb2 100644 --- a/util/config/configmanager.go +++ b/util/config/configmanager.go @@ -1,9 +1,11 @@ package config import ( + "crypto/tls" "fmt" "github.com/argoproj/argo-cd/common" + tlsutil "github.com/argoproj/argo-cd/util/tls" apiv1 "k8s.io/api/core/v1" apierr "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -17,6 +19,10 @@ type ArgoCDSettings struct { // ServerSignature holds the key used to generate JWT tokens. ServerSignature []byte + + // Certificate holds the certificate/private key for the ArgoCD API server. + // If nil, will run insecure without TLS. + Certificate *tls.Certificate } const ( @@ -25,6 +31,12 @@ const ( // configManagerServerSignatureKey designates the key for a server secret key inside a Kubernetes secret. configManagerServerSignatureKey = "server.secretkey" + + // configManagerServerCertificate designates the key for the public cert used in TLS + configManagerServerCertificate = "server.crt" + + // configManagerServerPrivateKey designates the key for the private key used in TLS + configManagerServerPrivateKey = "server.key" ) // ConfigManager holds config info for a new manager with which to access Kubernetes ConfigMaps. @@ -59,6 +71,16 @@ func (mgr *ConfigManager) GetSettings() (*ArgoCDSettings, error) { return nil, fmt.Errorf("server secret key not found") } settings.ServerSignature = secretKey + + serverCert, certOk := argoCDSecret.Data[configManagerServerCertificate] + serverKey, keyOk := argoCDSecret.Data[configManagerServerPrivateKey] + if certOk && keyOk { + cert, err := tls.X509KeyPair(serverCert, serverKey) + if err != nil { + return nil, fmt.Errorf("invalid x509 key pair %s/%s in secret: %s", configManagerServerCertificate, configManagerServerPrivateKey, err) + } + settings.Certificate = &cert + } return &settings, nil } @@ -88,6 +110,11 @@ func (mgr *ConfigManager) SaveSettings(settings *ArgoCDSettings) error { configManagerServerSignatureKey: string(settings.ServerSignature), configManagerAdminPasswordKey: settings.LocalUsers[common.ArgoCDAdminUsername], } + if settings.Certificate != nil { + certBytes, keyBytes := tlsutil.EncodeX509KeyPair(*settings.Certificate) + secretStringData[configManagerServerCertificate] = string(certBytes) + secretStringData[configManagerServerPrivateKey] = string(keyBytes) + } argoCDSecret, err := mgr.clientset.CoreV1().Secrets(mgr.namespace).Get(common.ArgoCDSecretName, metav1.GetOptions{}) if err != nil { if !apierr.IsNotFound(err) { diff --git a/util/grpc/grpc.go b/util/grpc/grpc.go index fde7e85b5fb87..45686329228dd 100644 --- a/util/grpc/grpc.go +++ b/util/grpc/grpc.go @@ -1,12 +1,15 @@ package grpc import ( + "net" "runtime/debug" + "time" "github.com/sirupsen/logrus" "golang.org/x/net/context" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" ) // PanicLoggerUnaryServerInterceptor returns a new unary server interceptor for recovering from panics and returning error @@ -34,3 +37,72 @@ func PanicLoggerStreamServerInterceptor(log *logrus.Entry) grpc.StreamServerInte return handler(srv, stream) } } + +// BlockingDial is a helper method to dial the given address, using optional TLS credentials, +// and blocking until the returned connection is ready. If the given credentials are nil, the +// connection will be insecure (plain-text). +// Lifted from: https://github.com/fullstorydev/grpcurl/blob/master/grpcurl.go +func BlockingDial(ctx context.Context, network, address string, creds credentials.TransportCredentials, opts ...grpc.DialOption) (*grpc.ClientConn, error) { + // grpc.Dial doesn't provide any information on permanent connection errors (like + // TLS handshake failures). So in order to provide good error messages, we need a + // custom dialer that can provide that info. That means we manage the TLS handshake. + result := make(chan interface{}, 1) + + writeResult := func(res interface{}) { + // non-blocking write: we only need the first result + select { + case result <- res: + default: + } + } + + dialer := func(address string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + conn, err := (&net.Dialer{Cancel: ctx.Done()}).Dial(network, address) + if err != nil { + writeResult(err) + return nil, err + } + if creds != nil { + conn, _, err = creds.ClientHandshake(ctx, address, conn) + if err != nil { + writeResult(err) + return nil, err + } + } + return conn, nil + } + + // Even with grpc.FailOnNonTempDialError, this call will usually timeout in + // the face of TLS handshake errors. So we can't rely on grpc.WithBlock() to + // know when we're done. So we run it in a goroutine and then use result + // channel to either get the channel or fail-fast. + go func() { + opts = append(opts, + grpc.WithBlock(), + grpc.FailOnNonTempDialError(true), + grpc.WithDialer(dialer), + grpc.WithInsecure(), // we are handling TLS, so tell grpc not to + ) + conn, err := grpc.DialContext(ctx, address, opts...) + var res interface{} + if err != nil { + res = err + } else { + res = conn + } + writeResult(res) + }() + + select { + case res := <-result: + if conn, ok := res.(*grpc.ClientConn); ok { + return conn, nil + } + return nil, res.(error) + case <-ctx.Done(): + return nil, ctx.Err() + } +} diff --git a/util/tls/tls.go b/util/tls/tls.go new file mode 100644 index 0000000000000..db308e797ea44 --- /dev/null +++ b/util/tls/tls.go @@ -0,0 +1,182 @@ +package tls + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "os" + "strings" + "time" +) + +const ( + DefaultRSABits = 2048 +) + +type CertOptions struct { + // Comma-separated hostnames and IPs to generate a certificate for + Host string + // Name of organization in certificate + Organization string + // Creation date + ValidFrom time.Time + // Duration that certificate is valid for + ValidFor time.Duration + // whether this cert should be its own Certificate Authority + IsCA bool + // Size of RSA key to generate. Ignored if --ecdsa-curve is set + RSABits int + // ECDSA curve to use to generate a key. Valid values are P224, P256 (recommended), P384, P521 + ECDSACurve string +} + +func publicKey(priv interface{}) interface{} { + switch k := priv.(type) { + case *rsa.PrivateKey: + return &k.PublicKey + case *ecdsa.PrivateKey: + return &k.PublicKey + default: + return nil + } +} + +func pemBlockForKey(priv interface{}) *pem.Block { + switch k := priv.(type) { + case *rsa.PrivateKey: + return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)} + case *ecdsa.PrivateKey: + b, err := x509.MarshalECPrivateKey(k) + if err != nil { + fmt.Fprintf(os.Stderr, "Unable to marshal ECDSA private key: %v", err) + os.Exit(2) + } + return &pem.Block{Type: "EC PRIVATE KEY", Bytes: b} + default: + return nil + } +} + +func generate(opts CertOptions) ([]byte, crypto.PrivateKey, error) { + if opts.Host == "" { + return nil, nil, fmt.Errorf("host not supplied") + } + + var privateKey crypto.PrivateKey + var err error + switch opts.ECDSACurve { + case "": + rsaBits := DefaultRSABits + if opts.RSABits != 0 { + rsaBits = opts.RSABits + } + privateKey, err = rsa.GenerateKey(rand.Reader, rsaBits) + case "P224": + privateKey, err = ecdsa.GenerateKey(elliptic.P224(), rand.Reader) + case "P256": + privateKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + case "P384": + privateKey, err = ecdsa.GenerateKey(elliptic.P384(), rand.Reader) + case "P521": + privateKey, err = ecdsa.GenerateKey(elliptic.P521(), rand.Reader) + default: + return nil, nil, fmt.Errorf("Unrecognized elliptic curve: %q", opts.ECDSACurve) + } + if err != nil { + return nil, nil, fmt.Errorf("failed to generate private key: %s", err) + } + + var notBefore time.Time + if opts.ValidFrom.IsZero() { + notBefore = time.Now() + } else { + notBefore = opts.ValidFrom + } + var validFor time.Duration + if opts.ValidFor == 0 { + validFor = 365 * 24 * time.Hour + } + notAfter := notBefore.Add(validFor) + + serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) + serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate serial number: %s", err) + } + + if opts.Organization == "" { + return nil, nil, fmt.Errorf("organization not supplied") + } + template := x509.Certificate{ + SerialNumber: serialNumber, + Subject: pkix.Name{ + Organization: []string{opts.Organization}, + }, + NotBefore: notBefore, + NotAfter: notAfter, + + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + hosts := strings.Split(opts.Host, ",") + for _, h := range hosts { + if ip := net.ParseIP(h); ip != nil { + template.IPAddresses = append(template.IPAddresses, ip) + } else { + template.DNSNames = append(template.DNSNames, h) + } + } + + if opts.IsCA { + template.IsCA = true + template.KeyUsage |= x509.KeyUsageCertSign + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, publicKey(privateKey), privateKey) + if err != nil { + return nil, nil, fmt.Errorf("Failed to create certificate: %s", err) + } + return certBytes, privateKey, nil +} + +// generatePEM generates a new certificate and key and returns it as PEM encoded bytes +func generatePEM(opts CertOptions) ([]byte, []byte, error) { + certBytes, privateKey, err := generate(opts) + if err != nil { + return nil, nil, err + } + certpem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certBytes}) + keypem := pem.EncodeToMemory(pemBlockForKey(privateKey)) + return certpem, keypem, nil +} + +// GenerateX509KeyPair generates a X509 key pair +func GenerateX509KeyPair(opts CertOptions) (*tls.Certificate, error) { + certpem, keypem, err := generatePEM(opts) + if err != nil { + return nil, err + } + cert, err := tls.X509KeyPair(certpem, keypem) + if err != nil { + return nil, err + } + return &cert, nil +} + +// EncodeX509KeyPair encodes a TLS Certificate into its pem encoded for storage +func EncodeX509KeyPair(cert tls.Certificate) ([]byte, []byte) { + certpem := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Certificate[0]}) + keypem := pem.EncodeToMemory(pemBlockForKey(cert.PrivateKey)) + return certpem, keypem +}