Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 7 additions & 49 deletions lib/teleterm/apiserver/handler/handler_gateways.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (

api "github.com/gravitational/teleport/lib/teleterm/api/protogen/golang/v1"
"github.com/gravitational/teleport/lib/teleterm/daemon"
"github.com/gravitational/teleport/lib/teleterm/gateway"
)

// CreateGateway creates a gateway
Expand All @@ -38,30 +37,18 @@ func (s *Handler) CreateGateway(ctx context.Context, req *api.CreateGatewayReque
return nil, trace.Wrap(err)
}

apiGateway, err := newAPIGateway(*gateway)
if err != nil {
return nil, trace.Wrap(err)
}

return apiGateway, nil
return gateway, nil
}

// ListGateways lists all gateways
func (s *Handler) ListGateways(ctx context.Context, req *api.ListGatewaysRequest) (*api.ListGatewaysResponse, error) {
gws := s.DaemonService.ListGateways()

apiGws := make([]*api.Gateway, 0, len(gws))
for _, gw := range gws {
apiGateway, err := newAPIGateway(gw)
if err != nil {
return nil, trace.Wrap(err)
}

apiGws = append(apiGws, apiGateway)
gateways, err := s.DaemonService.ListGateways()
if err != nil {
return nil, trace.Wrap(err)
}

return &api.ListGatewaysResponse{
Gateways: apiGws,
Gateways: gateways,
}, nil
}

Expand All @@ -74,25 +61,6 @@ func (s *Handler) RemoveGateway(ctx context.Context, req *api.RemoveGatewayReque
return &api.EmptyResponse{}, nil
}

func newAPIGateway(gateway gateway.Gateway) (*api.Gateway, error) {
command, err := gateway.CLICommand()
if err != nil {
return nil, trace.Wrap(err)
}

return &api.Gateway{
Uri: gateway.URI().String(),
TargetUri: gateway.TargetURI(),
TargetName: gateway.TargetName(),
TargetUser: gateway.TargetUser(),
TargetSubresourceName: gateway.TargetSubresourceName(),
Protocol: gateway.Protocol(),
LocalAddress: gateway.LocalAddress(),
LocalPort: gateway.LocalPort(),
CliCommand: command,
}, nil
}

// RestartGateway stops a gateway and starts a new with identical parameters but fresh certs,
// keeping the original URI.
func (s *Handler) RestartGateway(ctx context.Context, req *api.RestartGatewayRequest) (*api.EmptyResponse, error) {
Expand All @@ -113,12 +81,7 @@ func (s *Handler) SetGatewayTargetSubresourceName(ctx context.Context, req *api.
return nil, trace.Wrap(err)
}

apiGateway, err := newAPIGateway(*gateway)
if err != nil {
return nil, trace.Wrap(err)
}

return apiGateway, nil
return gateway, nil
}

// SetGatewayLocalPort restarts the gateway under the new port without fetching new certs.
Expand All @@ -128,10 +91,5 @@ func (s *Handler) SetGatewayLocalPort(ctx context.Context, req *api.SetGatewayLo
return nil, trace.Wrap(err)
}

apiGateway, err := newAPIGateway(*gateway)
if err != nil {
return nil, trace.Wrap(err)
}

return apiGateway, nil
return gateway, nil
}
6 changes: 3 additions & 3 deletions lib/teleterm/clusters/dbcmd_cli_command_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ import (
// DbcmdCLICommandProvider provides CLI commands for database gateways. It needs Storage to read
// fresh profile state from the disk.
type DbcmdCLICommandProvider struct {
storage StorageByResourceURI
storage ClusterGetter
execer dbcmd.Execer
}

type StorageByResourceURI interface {
type ClusterGetter interface {
GetByResourceURI(string) (*Cluster, error)
}

func NewDbcmdCLICommandProvider(storage StorageByResourceURI, execer dbcmd.Execer) DbcmdCLICommandProvider {
func NewDbcmdCLICommandProvider(storage ClusterGetter, execer dbcmd.Execer) DbcmdCLICommandProvider {
return DbcmdCLICommandProvider{
storage: storage,
execer: execer,
Expand Down
12 changes: 9 additions & 3 deletions lib/teleterm/daemon/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/sirupsen/logrus"
"google.golang.org/grpc"

"github.com/gravitational/teleport/lib/client/db/dbcmd"
"github.com/gravitational/teleport/lib/teleterm/clusters"
"github.com/gravitational/teleport/lib/teleterm/gateway"
)
Expand All @@ -30,9 +31,10 @@ type Config struct {
// Storage is a storage service that reads/writes to tsh profiles
Storage *clusters.Storage
// Log is a component logger
Log *logrus.Entry
GatewayCreator GatewayCreator
TCPPortAllocator gateway.TCPPortAllocator
Log *logrus.Entry
GatewayCreator GatewayCreator
TCPPortAllocator gateway.TCPPortAllocator
CLICommandProvider gateway.CLICommandProvider
// CreateTshdEventsClientCredsFunc lazily creates creds for the tshd events server ran by the
// Electron app. This is to ensure that the server public key is written to the disk under the
// expected location by the time we get around to creating the client.
Expand All @@ -55,6 +57,10 @@ func (c *Config) CheckAndSetDefaults() error {
c.TCPPortAllocator = gateway.NetTCPPortAllocator{}
}

if c.CLICommandProvider == nil {
c.CLICommandProvider = clusters.NewDbcmdCLICommandProvider(c.Storage, dbcmd.SystemExecer{})
}

if c.Log == nil {
c.Log = logrus.NewEntry(logrus.StandardLogger()).WithField(trace.Component, "daemon")
}
Expand Down
37 changes: 22 additions & 15 deletions lib/teleterm/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"google.golang.org/grpc"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/client/db/dbcmd"
api "github.com/gravitational/teleport/lib/teleterm/api/protogen/golang/v1"
"github.com/gravitational/teleport/lib/teleterm/clusters"
"github.com/gravitational/teleport/lib/teleterm/gateway"
Expand Down Expand Up @@ -148,7 +147,7 @@ func (s *Service) ClusterLogout(ctx context.Context, uri string) error {
}

// CreateGateway creates a gateway to given targetURI
func (s *Service) CreateGateway(ctx context.Context, params CreateGatewayParams) (*gateway.Gateway, error) {
func (s *Service) CreateGateway(ctx context.Context, params CreateGatewayParams) (*api.Gateway, error) {
s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -157,7 +156,8 @@ func (s *Service) CreateGateway(ctx context.Context, params CreateGatewayParams)
return nil, trace.Wrap(err)
}

return gateway, nil
protoGateway, err := gateway.ToProto()
return protoGateway, trace.Wrap(err)
}

type GatewayCreator interface {
Expand All @@ -166,13 +166,12 @@ type GatewayCreator interface {

// createGateway assumes that mu is already held by a public method.
func (s *Service) createGateway(ctx context.Context, params CreateGatewayParams) (*gateway.Gateway, error) {
cliCommandProvider := clusters.NewDbcmdCLICommandProvider(s.cfg.Storage, dbcmd.SystemExecer{})
clusterCreateGatewayParams := clusters.CreateGatewayParams{
TargetURI: params.TargetURI,
TargetUser: params.TargetUser,
TargetSubresourceName: params.TargetSubresourceName,
LocalPort: params.LocalPort,
CLICommandProvider: cliCommandProvider,
CLICommandProvider: s.cfg.CLICommandProvider,
TCPPortAllocator: s.cfg.TCPPortAllocator,
}

Expand Down Expand Up @@ -267,21 +266,25 @@ func (s *Service) findGateway(gatewayURI string) (*gateway.Gateway, error) {
}

// ListGateways lists gateways
func (s *Service) ListGateways() []gateway.Gateway {
func (s *Service) ListGateways() ([]*api.Gateway, error) {
s.mu.RLock()
defer s.mu.RUnlock()

gws := make([]gateway.Gateway, 0, len(s.gateways))
gws := make([]*api.Gateway, 0, len(s.gateways))
for _, gateway := range s.gateways {
gws = append(gws, *gateway)
protoGateway, err := gateway.ToProto()
if err != nil {
return nil, trace.Wrap(err)
}
gws = append(gws, protoGateway)
}

return gws
return gws, nil
}

// SetGatewayTargetSubresourceName updates the TargetSubresourceName field of a gateway stored in
// s.gateways.
func (s *Service) SetGatewayTargetSubresourceName(gatewayURI, targetSubresourceName string) (*gateway.Gateway, error) {
func (s *Service) SetGatewayTargetSubresourceName(gatewayURI, targetSubresourceName string) (*api.Gateway, error) {
s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -292,7 +295,8 @@ func (s *Service) SetGatewayTargetSubresourceName(gatewayURI, targetSubresourceN

gateway.SetTargetSubresourceName(targetSubresourceName)

return gateway, nil
protoGateway, err := gateway.ToProto()
return protoGateway, trace.Wrap(err)
}

// SetGatewayLocalPort creates a new gateway with the given port, swaps it with the old gateway
Expand All @@ -304,7 +308,7 @@ func (s *Service) SetGatewayTargetSubresourceName(gatewayURI, targetSubresourceN
// correct that mistake and choose a different port.
//
// SetGatewayLocalPort is a noop if port is equal to the existing port.
func (s *Service) SetGatewayLocalPort(gatewayURI, localPort string) (*gateway.Gateway, error) {
func (s *Service) SetGatewayLocalPort(gatewayURI, localPort string) (*api.Gateway, error) {
s.mu.Lock()
defer s.mu.Unlock()

Expand All @@ -314,7 +318,8 @@ func (s *Service) SetGatewayLocalPort(gatewayURI, localPort string) (*gateway.Ga
}

if localPort == oldGateway.LocalPort() {
return oldGateway, nil
protoOldGateway, err := oldGateway.ToProto()
return protoOldGateway, trace.Wrap(err)
}

newGateway, err := gateway.NewWithLocalPort(oldGateway, localPort)
Expand Down Expand Up @@ -345,7 +350,8 @@ func (s *Service) SetGatewayLocalPort(gatewayURI, localPort string) (*gateway.Ga
}
}()

return newGateway, nil
protoNewGateway, err := newGateway.ToProto()
return protoNewGateway, trace.Wrap(err)
}

// GetAllServers returns a full list of nodes without pagination or sorting.
Expand Down Expand Up @@ -588,7 +594,8 @@ func (s *Service) TransferFile(ctx context.Context, request *api.FileTransferReq
// Service is the daemon service
type Service struct {
cfg *Config
mu sync.RWMutex
// mu guards gateways.
mu sync.RWMutex
// closeContext is canceled when Service is getting stopped. It is used as a context for the calls
// to the tshd events gRPC client.
closeContext context.Context
Expand Down
38 changes: 25 additions & 13 deletions lib/teleterm/daemon/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package daemon

import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -85,6 +86,13 @@ func (m *mockGatewayCreator) CreateGateway(ctx context.Context, params clusters.
return gateway, nil
}

type mockCLICommandProvider struct{}

func (m mockCLICommandProvider) GetCommand(gateway *gateway.Gateway) (string, error) {
command := fmt.Sprintf("%s/%s", gateway.TargetName(), gateway.TargetSubresourceName())
return command, nil
}

type gatewayCRUDTestContext struct {
nameToGateway map[string]*gateway.Gateway
mockGatewayCreator *mockGatewayCreator
Expand Down Expand Up @@ -115,16 +123,17 @@ func TestGatewayCRUD(t *testing.T) {
name: "ListGateways",
gatewayNamesToCreate: []string{"gateway1", "gateway2"},
testFunc: func(t *testing.T, c *gatewayCRUDTestContext, daemon *Service) {
gateways := daemon.ListGateways()
gatewayURIs := map[uri.ResourceURI]struct{}{}
protoGateways, err := daemon.ListGateways()
require.NoError(t, err)
gatewayURIs := map[string]struct{}{}

for _, gateway := range gateways {
gatewayURIs[gateway.URI()] = struct{}{}
for _, protoGateway := range protoGateways {
gatewayURIs[protoGateway.Uri] = struct{}{}
}

require.Equal(t, 2, len(gateways))
require.Contains(t, gatewayURIs, c.nameToGateway["gateway1"].URI())
require.Contains(t, gatewayURIs, c.nameToGateway["gateway2"].URI())
require.Equal(t, 2, len(protoGateways))
require.Contains(t, gatewayURIs, c.nameToGateway["gateway1"].URI().String())
require.Contains(t, gatewayURIs, c.nameToGateway["gateway2"].URI().String())
},
},
{
Expand Down Expand Up @@ -170,9 +179,9 @@ func TestGatewayCRUD(t *testing.T) {

require.Equal(t, 0, oldListener.CloseCallCount)

updatedGateway, err := daemon.SetGatewayLocalPort(oldGateway.URI().String(), "12345")
updatedProtoGateway, err := daemon.SetGatewayLocalPort(oldGateway.URI().String(), "12345")
require.NoError(t, err)
require.Equal(t, "12345", updatedGateway.LocalPort())
require.Equal(t, "12345", updatedProtoGateway.LocalPort)
updatedGatewayAddress := c.mockTCPPortAllocator.RecentListener().RealAddr().String()

// Check if the restarted gateway is still available under the same URI.
Expand Down Expand Up @@ -241,23 +250,26 @@ func TestGatewayCRUD(t *testing.T) {
require.NoError(t, err)

daemon, err := New(Config{
Storage: storage,
GatewayCreator: mockGatewayCreator,
TCPPortAllocator: tt.tcpPortAllocator,
Storage: storage,
GatewayCreator: mockGatewayCreator,
TCPPortAllocator: tt.tcpPortAllocator,
CLICommandProvider: &mockCLICommandProvider{},
})
require.NoError(t, err)

nameToGateway := make(map[string]*gateway.Gateway, len(tt.gatewayNamesToCreate))

for _, gatewayName := range tt.gatewayNamesToCreate {
gatewayName := gatewayName
gateway, err := daemon.CreateGateway(context.Background(), CreateGatewayParams{
protoGateway, err := daemon.CreateGateway(context.Background(), CreateGatewayParams{
TargetURI: uri.NewClusterURI("foo").AppendDB(gatewayName).String(),
TargetUser: "alice",
TargetSubresourceName: "",
LocalPort: "",
})
require.NoError(t, err)
gateway, err := daemon.findGateway(protoGateway.Uri)
require.NoError(t, err)

nameToGateway[gatewayName] = gateway
}
Expand Down
Loading