Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions lib/teleterm/clusters/cluster_databases.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,12 @@ func (c *Cluster) GetDatabases(ctx context.Context, r *api.GetDatabasesRequest)
}

// ReissueDBCerts issues new certificates for specific DB access
func (c *Cluster) ReissueDBCerts(ctx context.Context, user string, db types.Database) error {
func (c *Cluster) ReissueDBCerts(ctx context.Context, routeToDatabase tlsca.RouteToDatabase) error {
// When generating certificate for MongoDB access, database username must
// be encoded into it. This is required to be able to tell which database
// user to authenticate the connection as.
if db.GetProtocol() == libdefaults.ProtocolMongoDB && user == "" {
return trace.BadParameter("please provide the database user name using --db-user flag")
if routeToDatabase.Protocol == libdefaults.ProtocolMongoDB && routeToDatabase.Username == "" {
return trace.BadParameter("the username must be present for MongoDB connections")
}

err := addMetadataToRetryableError(ctx, func() error {
Expand All @@ -177,9 +177,9 @@ func (c *Cluster) ReissueDBCerts(ctx context.Context, user string, db types.Data
err = c.clusterClient.ReissueUserCerts(ctx, client.CertCacheKeep, client.ReissueParams{
RouteToCluster: c.clusterClient.SiteName,
RouteToDatabase: proto.RouteToDatabase{
ServiceName: db.GetName(),
Protocol: db.GetProtocol(),
Username: user,
ServiceName: routeToDatabase.ServiceName,
Protocol: routeToDatabase.Protocol,
Username: routeToDatabase.Username,
},
AccessRequests: c.status.ActiveRequests.AccessRequests,
})
Expand All @@ -194,11 +194,7 @@ func (c *Cluster) ReissueDBCerts(ctx context.Context, user string, db types.Data
}

// Update the database-specific connection profile file.
err = dbprofile.Add(ctx, c.clusterClient, tlsca.RouteToDatabase{
ServiceName: db.GetName(),
Protocol: db.GetProtocol(),
Username: user,
}, c.status)
err = dbprofile.Add(ctx, c.clusterClient, routeToDatabase, c.status)
if err != nil {
return trace.Wrap(err)
}
Expand Down
18 changes: 14 additions & 4 deletions lib/teleterm/clusters/cluster_gateways.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ package clusters
import (
"context"

"github.com/gravitational/teleport/lib/teleterm/gateway"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/lib/teleterm/gateway"
"github.com/gravitational/teleport/lib/tlsca"
)

type CreateGatewayParams struct {
Expand All @@ -36,6 +37,7 @@ type CreateGatewayParams struct {
LocalPort string
CLICommandProvider gateway.CLICommandProvider
TCPPortAllocator gateway.TCPPortAllocator
OnExpiredCert gateway.OnExpiredCertFunc
}

// CreateGateway creates a gateway
Expand All @@ -45,7 +47,13 @@ func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams)
return nil, trace.Wrap(err)
}

if err := c.ReissueDBCerts(ctx, params.TargetUser, db); err != nil {
routeToDatabase := tlsca.RouteToDatabase{
ServiceName: db.GetName(),
Protocol: db.GetProtocol(),
Username: params.TargetUser,
}

if err := c.ReissueDBCerts(ctx, routeToDatabase); err != nil {
return nil, trace.Wrap(err)
}

Expand All @@ -60,9 +68,11 @@ func (c *Cluster) CreateGateway(ctx context.Context, params CreateGatewayParams)
CertPath: c.status.DatabaseCertPathForCluster(c.clusterClient.SiteName, db.GetName()),
Insecure: c.clusterClient.InsecureSkipVerify,
WebProxyAddr: c.clusterClient.WebProxyAddr,
Log: c.Log.WithField("gateway", params.TargetURI),
Log: c.Log,
CLICommandProvider: params.CLICommandProvider,
TCPPortAllocator: params.TCPPortAllocator,
OnExpiredCert: params.OnExpiredCert,
Clock: c.clock,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
2 changes: 1 addition & 1 deletion lib/teleterm/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ func (s *Service) SetGatewayLocalPort(gatewayURI, localPort string) (*gateway.Ga
return oldGateway, nil
}

newGateway, err := gateway.NewWithLocalPort(*oldGateway, localPort)
newGateway, err := gateway.NewWithLocalPort(oldGateway, localPort)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
4 changes: 3 additions & 1 deletion lib/teleterm/daemon/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ func (m *mockGatewayCreator) CreateGateway(ctx context.Context, params clusters.
return nil, trace.Wrap(err)
}
m.t.Cleanup(func() {
gateway.Close()
if err := gateway.Close(); err != nil {
m.t.Logf("Ignoring error from gateway.Close() during cleanup, it appears the gateway was already closed. The error was: %s", err)
}
})

return gateway, nil
Expand Down
49 changes: 43 additions & 6 deletions lib/teleterm/gateway/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@ limitations under the License.
package gateway

import (
"context"
"runtime"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/teleterm/api/uri"
"github.com/gravitational/trace"

"github.com/google/uuid"
"github.com/sirupsen/logrus"
"github.com/gravitational/teleport/lib/tlsca"
)

// Config describes gateway configuration
Expand All @@ -39,7 +42,7 @@ type Config struct {
// TargetUser is the target user name
TargetUser string
// TargetSubresourceName points at a subresource of the remote resource, for example a database
// name on a database server.
// name on a database server. It is used only for generating the CLI command.
TargetSubresourceName string

// Port is the gateway port
Expand All @@ -63,8 +66,21 @@ type Config struct {
// TCPPortAllocator creates listeners on the given ports. This interface lets us avoid occupying
// hardcoded ports in tests.
TCPPortAllocator TCPPortAllocator
// Clock is used by Gateway.localProxy to check cert expiration.
Clock clockwork.Clock
// OnExpiredCert is called when a new downstream connection is accepted by the
// gateway but cannot be proxied because the cert used by the gateway has expired.
//
// Handling of the connection is blocked until OnExpiredCert returns.
OnExpiredCert OnExpiredCertFunc
}

// OnExpiredCertFunc is the type of a function that is called when a new downstream connection is
// accepted by the gateway but cannot be proxied because the cert used by the gateway has expired.
//
// Handling of the connection is blocked until the function returns.
type OnExpiredCertFunc func(context.Context, *Gateway) error

// CheckAndSetDefaults checks and sets the defaults
func (c *Config) CheckAndSetDefaults() error {
if c.URI.String() == "" {
Expand All @@ -84,9 +100,14 @@ func (c *Config) CheckAndSetDefaults() error {
}

if c.Log == nil {
c.Log = logrus.WithField("gateway", c.URI.String())
c.Log = logrus.NewEntry(logrus.StandardLogger())
}

c.Log = c.Log.WithFields(logrus.Fields{
"resource": c.TargetURI,
"gateway": c.URI.String(),
})

if c.TargetName == "" {
return trace.BadParameter("missing target name")
}
Expand All @@ -103,5 +124,21 @@ func (c *Config) CheckAndSetDefaults() error {
c.TCPPortAllocator = NetTCPPortAllocator{}
}

if c.Clock == nil {
c.Clock = clockwork.NewRealClock()
}

return nil
}

// RouteToDatabase returns tlsca.RouteToDatabase based on the config of the gateway.
//
// The tlsca.RouteToDatabase.Database field is skipped, as it's an optional field and gateways can
// change their Config.TargetSubresourceName at any moment.
func (c *Config) RouteToDatabase() tlsca.RouteToDatabase {
return tlsca.RouteToDatabase{
ServiceName: c.TargetName,
Protocol: c.Protocol,
Username: c.TargetUser,
}
}
47 changes: 42 additions & 5 deletions lib/teleterm/gateway/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ func New(cfg Config) (*Gateway, error) {
return nil, trace.Wrap(err)
}

cfg.LocalPort = port

protocol, err := alpncommon.ToALPNProtocol(cfg.Protocol)
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -80,35 +82,52 @@ func New(cfg Config) (*Gateway, error) {
return nil, trace.Wrap(err)
}

localProxy, err := alpn.NewLocalProxy(alpn.LocalProxyConfig{
localProxyConfig := alpn.LocalProxyConfig{
InsecureSkipVerify: cfg.Insecure,
RemoteProxyAddr: cfg.WebProxyAddr,
Protocols: []alpncommon.Protocol{protocol},
Listener: listener,
ParentContext: closeContext,
SNI: address.Host(),
Certs: []tls.Certificate{tlsCert},
})
Clock: cfg.Clock,
}

localProxyMiddleware := &localProxyMiddleware{
log: cfg.Log,
dbRoute: cfg.RouteToDatabase(),
}

if cfg.OnExpiredCert != nil {
localProxyConfig.Middleware = localProxyMiddleware
}

localProxy, err := alpn.NewLocalProxy(localProxyConfig)
if err != nil {
return nil, trace.Wrap(err)
}

cfg.LocalPort = port

gateway := &Gateway{
cfg: &cfg,
closeContext: closeContext,
closeCancel: closeCancel,
localProxy: localProxy,
}

if cfg.OnExpiredCert != nil {
localProxyMiddleware.onExpiredCert = func(ctx context.Context) error {
err := cfg.OnExpiredCert(ctx, gateway)
return trace.Wrap(err)
}
}

ok = true
return gateway, nil
}

// NewWithLocalPort initializes a copy of an existing gateway which has all config fields identical
// to the existing gateway with the exception of the local port.
func NewWithLocalPort(gateway Gateway, port string) (*Gateway, error) {
func NewWithLocalPort(gateway *Gateway, port string) (*Gateway, error) {
if port == gateway.LocalPort() {
return nil, trace.BadParameter("port is already set to %s", port)
}
Expand Down Expand Up @@ -207,6 +226,24 @@ func (g *Gateway) CLICommand() (string, error) {
return cliCommand, nil
}

// ReloadCert loads the key pair from cfg.CertPath & cfg.KeyPath and updates the cert of the running
// local proxy. This is typically done after the cert is reissued and saved to disk.
//
// In the future, we're probably going to make this method accept the cert as an arg rather than
// reading from disk.
func (g *Gateway) ReloadCert() error {
g.cfg.Log.Debug("Reloading cert")

tlsCert, err := keys.LoadX509KeyPair(g.cfg.CertPath, g.cfg.KeyPath)
if err != nil {
return trace.Wrap(err)
}

g.localProxy.SetCerts([]tls.Certificate{tlsCert})

return nil
}

// Gateway describes local proxy that creates a gateway to the remote Teleport resource.
//
// Gateway is not safe for concurrent use in itself. However, all access to gateways is gated by
Expand Down
16 changes: 11 additions & 5 deletions lib/teleterm/gateway/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ func TestGatewayStart(t *testing.T) {
},
)
require.NoError(t, err)
t.Cleanup(func() { gateway.Close() })
t.Cleanup(func() {
if err := gateway.Close(); err != nil {
t.Logf("Ignoring error from gateway.Close() during cleanup, it appears the gateway was already closed. The error was: %s", err)
}
})
gatewayAddress := net.JoinHostPort(gateway.LocalAddress(), gateway.LocalPort())

require.NotEmpty(t, gateway.LocalPort())
Expand All @@ -91,7 +95,7 @@ func TestNewWithLocalPortStartsListenerOnNewPortIfPortIsFree(t *testing.T) {
tcpPortAllocator := gatewaytest.MockTCPPortAllocator{}
oldGateway := createAndServeGateway(t, &tcpPortAllocator)

newGateway, err := NewWithLocalPort(*oldGateway, "12345")
newGateway, err := NewWithLocalPort(oldGateway, "12345")
require.NoError(t, err)
require.Equal(t, "12345", newGateway.LocalPort())
require.Equal(t, oldGateway.URI(), newGateway.URI())
Expand All @@ -109,7 +113,7 @@ func TestNewWithLocalPortReturnsErrorIfNewPortIsOccupied(t *testing.T) {
tcpPortAllocator := gatewaytest.MockTCPPortAllocator{PortsInUse: []string{"12345"}}
gateway := createAndServeGateway(t, &tcpPortAllocator)

_, err := NewWithLocalPort(*gateway, "12345")
_, err := NewWithLocalPort(gateway, "12345")
require.ErrorContains(t, err, "address already in use")
}

Expand All @@ -119,7 +123,7 @@ func TestNewWithLocalPortReturnsErrorIfNewPortEqualsOldPort(t *testing.T) {
port := gateway.LocalPort()
expectedErrMessage := fmt.Sprintf("port is already set to %s", port)

_, err := NewWithLocalPort(*gateway, port)
_, err := NewWithLocalPort(gateway, port)
require.True(t, trace.IsBadParameter(err), "Expected err to be a BadParameter error")
require.ErrorContains(t, err, expectedErrMessage)
}
Expand Down Expand Up @@ -171,7 +175,9 @@ func serveGatewayAndBlockUntilItAcceptsConnections(t *testing.T, gateway *Gatewa
serveErr <- err
}()
t.Cleanup(func() {
gateway.Close()
if err := gateway.Close(); err != nil {
t.Logf("Ignoring error from gateway.Close() during cleanup, it appears the gateway was already closed. The error was: %s", err)
}
require.NoError(t, <-serveErr, "Gateway %s returned error from Serve()", gateway.URI())
})

Expand Down
Loading