Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 34 additions & 94 deletions lib/teleterm/apiserver/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,18 @@
package apiserver

import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"os"
"path/filepath"
"strings"

"github.com/gravitational/teleport/api/utils/keys"
api "github.com/gravitational/teleport/lib/teleterm/api/protogen/golang/v1"
"github.com/gravitational/teleport/lib/teleterm/apiserver/handler"

"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/trace"

"google.golang.org/grpc"
"google.golang.org/grpc/credentials"

"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
)

const (
// Server certificate file name (created by tsh), Connect expects exactly the same name
tshServerCertFileName = "tsh_server.crt"
// Client certificate file name (created by Connect)
clientCertFileName = "client.crt"
"google.golang.org/grpc"
)

// New creates an instance of API Server
Expand All @@ -48,29 +35,47 @@ func New(cfg Config) (*APIServer, error) {
return nil, trace.Wrap(err)
}

serviceHandler, err := handler.New(
handler.Config{
DaemonService: cfg.Daemon,
},
)
if err != nil {
return nil, trace.Wrap(err)
}
// Create the listener, set up the credentials and the server.

ls, err := newListener(cfg.HostAddr)
if err != nil {
return nil, trace.Wrap(err)
}

grpcCredentials, err := getGrpcCredentials(cfg)
serverOptions := []grpc.ServerOption{grpc.ChainUnaryInterceptor(withErrorHandling(cfg.Log))}
rendererCertPath := filepath.Join(cfg.CertsDir, rendererCertFileName)
tshdCertPath := filepath.Join(cfg.CertsDir, tshdCertFileName)
shouldUseMTLS := strings.HasPrefix(cfg.HostAddr, "tcp://")

if shouldUseMTLS {
tshdKeyPair, err := generateAndSaveCert(tshdCertPath)
if err != nil {
return nil, trace.Wrap(err)
}

// rendererCertPath will be read on an incoming client connection so we can assume that at this
// point the renderer process has saved its public key under that path.
withTshdCreds, err := createServerCredentials(tshdKeyPair, rendererCertPath)
if err != nil {
return nil, trace.Wrap(err)
}

serverOptions = append(serverOptions, withTshdCreds)
}

grpcServer := grpc.NewServer(serverOptions...)

// Create Terminal service.

serviceHandler, err := handler.New(
handler.Config{
DaemonService: cfg.Daemon,
},
)
if err != nil {
return nil, trace.Wrap(err)
}

grpcServer := grpc.NewServer(grpcCredentials, grpc.ChainUnaryInterceptor(
withErrorHandling(cfg.Log),
))

api.RegisterTerminalServiceServer(grpcServer, serviceHandler)

return &APIServer{cfg, ls, grpcServer}, nil
Expand Down Expand Up @@ -119,68 +124,3 @@ type APIServer struct {
// grpc is an instance of grpc server
grpcServer *grpc.Server
}

func getGrpcCredentials(cfg Config) (grpc.ServerOption, error) {
uri, err := utils.ParseAddr(cfg.HostAddr)

if err != nil {
return nil, trace.BadParameter("invalid host address: %s", cfg.HostAddr)
}

if uri.Network() != "unix" {
keyPair, err := generateKeyPair(cfg.CertsDir)
if err != nil {
return nil, trace.Wrap(err)
}

return grpc.Creds(keyPair), nil
}

return grpc.Creds(nil), nil
}

func generateKeyPair(certsDir string) (credentials.TransportCredentials, error) {
// File is first saved using under `tshServerCertTempPath` and then renamed to `tshServerCertFullPath`.
// It prevents Connect from reading half written file.
tshServerCertFullPath := filepath.Join(certsDir, tshServerCertFileName)
tshServerCertTempPath := tshServerCertFullPath + ".tmp"

cert, err := utils.GenerateSelfSignedCert([]string{"localhost"})
if err != nil {
return nil, trace.Wrap(err, "failed to generate a certificate")
}

err = os.WriteFile(tshServerCertTempPath, cert.Cert, 0600)
if err != nil {
return nil, trace.Wrap(err, "failed to save server certificate")
}

err = os.Rename(tshServerCertTempPath, tshServerCertFullPath)
if err != nil {
return nil, trace.Wrap(err, "failed to rename server certificate")
}

certificate, err := keys.X509KeyPair(cert.Cert, cert.PrivateKey)
if err != nil {
return nil, trace.Wrap(err, "failed to parse server certificates")
}

tlsConfig := &tls.Config{
GetConfigForClient: func(info *tls.ClientHelloInfo) (*tls.Config, error) {
caCert, err := os.ReadFile(filepath.Join(certsDir, clientCertFileName))
if err != nil {
return nil, trace.Wrap(err, "failed to read client certificate file")
}
caPool := x509.NewCertPool()
if !caPool.AppendCertsFromPEM(caCert) {
return nil, trace.Wrap(err, "failed to add client CA file")
}
return &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{certificate},
ClientCAs: caPool,
}, nil
},
}
return credentials.NewTLS(tlsConfig), nil
}
103 changes: 103 additions & 0 deletions lib/teleterm/apiserver/grpccredentials.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Copyright 2022 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package apiserver

import (
"crypto/tls"
"crypto/x509"
"os"
"path/filepath"

"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/lib/utils"

"github.com/gravitational/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)

const (
// tshdCertFileName is the file name of the cert created by the tshd process. The Electron app
// expects it to exist under this name in the certs dir passed through a flag to tshd.
tshdCertFileName = "tshd.crt"
// rendererCertFileName is the file name of the cert created by the renderer process of the
// Electron app.
rendererCertFileName = "renderer.crt"
)

// createServerCredentials creates mTLS credentials for a gRPC server. The client cert file is read
// only on an incoming connection, not upfront. Otherwise we'd need to wait for the client cert file
// to exist before booting up the server.
func createServerCredentials(serverKeyPair tls.Certificate, clientCertPath string) (grpc.ServerOption, error) {
config := &tls.Config{
GetConfigForClient: func(_ *tls.ClientHelloInfo) (*tls.Config, error) {
clientCert, err := os.ReadFile(clientCertPath)
if err != nil {
return nil, trace.Wrap(err, "failed to read the client cert file")
}

certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(clientCert) {
return nil, trace.BadParameter("failed to add the client cert to the pool")
}

return &tls.Config{
ClientAuth: tls.RequireAndVerifyClientCert,
Certificates: []tls.Certificate{serverKeyPair},
ClientCAs: certPool,
}, nil
},
}

return grpc.Creds(credentials.NewTLS(config)), nil
}

func generateAndSaveCert(targetPath string) (tls.Certificate, error) {
// The cert is first saved under a temp path and then renamed to targetPath. This prevents other
// processes from reading a half-written file.
tempFile, err := os.CreateTemp(filepath.Dir(targetPath), filepath.Base(targetPath))
if err != nil {
return tls.Certificate{}, trace.Wrap(err)
}
defer os.Remove(tempFile.Name())

cert, err := utils.GenerateSelfSignedCert([]string{"localhost"})
if err != nil {
return tls.Certificate{}, trace.Wrap(err, "failed to generate the certificate")
}

if err = tempFile.Chmod(0600); err != nil {
return tls.Certificate{}, trace.Wrap(err)
}

if _, err = tempFile.Write(cert.Cert); err != nil {
return tls.Certificate{}, trace.Wrap(err)
}

if err = tempFile.Close(); err != nil {
return tls.Certificate{}, trace.Wrap(err)
}

if err = os.Rename(tempFile.Name(), targetPath); err != nil {
return tls.Certificate{}, trace.Wrap(err)
}

certificate, err := keys.X509KeyPair(cert.Cert, cert.PrivateKey)
if err != nil {
return tls.Certificate{}, trace.Wrap(err)
}

return certificate, nil
}