diff --git a/lib/tbot/bot/service.go b/lib/tbot/bot/service.go deleted file mode 100644 index d2e8ae5b0973b..0000000000000 --- a/lib/tbot/bot/service.go +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package bot - -import "context" - -// Service is a long-running sub-component of tbot. -type Service interface { - // String returns a human-readable name for the service that can be used - // in logging. It should identify the type of the service and any top - // level configuration that could distinguish it from a same-type service. - String() string - // Run starts the service and blocks until the service exits. It should - // return a nil error if the service exits successfully and an error - // if it is unable to proceed. It should exit gracefully if the context - // is canceled. - Run(ctx context.Context) error -} - -// OneShotService is a [Service] that offers a mode in which it runs a single -// time and then exits. This aligns with the `--oneshot` mode of tbot. -type OneShotService interface { - Service - // OneShot runs the service once and then exits. It should return a nil - // error if the service exits successfully and an error if it is unable - // to proceed. It should exit gracefully if the context is canceled. - OneShot(ctx context.Context) error -} diff --git a/lib/tbot/config/config.go b/lib/tbot/config/config.go index e27ccd8eade48..37800201c4206 100644 --- a/lib/tbot/config/config.go +++ b/lib/tbot/config/config.go @@ -270,7 +270,7 @@ type BotConfig struct { Onboarding OnboardingConfig `yaml:"onboarding,omitempty"` Storage *StorageConfig `yaml:"storage,omitempty"` Outputs Outputs `yaml:"outputs,omitempty"` - Services Services `yaml:"services,omitempty"` + Services ServiceConfigs `yaml:"services,omitempty"` Debug bool `yaml:"debug"` AuthServer string `yaml:"auth_server,omitempty"` @@ -374,6 +374,13 @@ func (conf *BotConfig) CheckAndSetDefaults() error { } } + // Validate configured services + for i, service := range conf.Services { + if err := service.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err, "validating service[%d]", i) + } + } + if conf.CertificateTTL == 0 { conf.CertificateTTL = DefaultCertificateTTL } @@ -418,11 +425,17 @@ func (conf *BotConfig) CheckAndSetDefaults() error { return nil } -// Services assists polymorphic unmarshaling of a slice of Services. -type Services []bot.Service +// ServiceConfig is an interface over the various service configurations. +type ServiceConfig interface { + Type() string + CheckAndSetDefaults() error +} + +// ServiceConfigs assists polymorphic unmarshaling of a slice of ServiceConfigs. +type ServiceConfigs []ServiceConfig -func (o *Services) UnmarshalYAML(node *yaml.Node) error { - var out []bot.Service +func (o *ServiceConfigs) UnmarshalYAML(node *yaml.Node) error { + var out []ServiceConfig for _, node := range node.Content { header := struct { Type string `yaml:"type"` @@ -438,6 +451,12 @@ func (o *Services) UnmarshalYAML(node *yaml.Node) error { return trace.Wrap(err) } out = append(out, v) + case DatabaseTunnelServiceType: + v := &DatabaseTunnelService{} + if err := node.Decode(v); err != nil { + return trace.Wrap(err) + } + out = append(out, v) default: return trace.BadParameter("unrecognized service type (%s)", header.Type) } diff --git a/lib/tbot/config/config_test.go b/lib/tbot/config/config_test.go index e784a470bde0e..f328819e813c3 100644 --- a/lib/tbot/config/config_test.go +++ b/lib/tbot/config/config_test.go @@ -248,7 +248,7 @@ func TestBotConfig_YAML(t *testing.T) { }, }, }, - Services: []bot.Service{ + Services: []ServiceConfig{ &ExampleService{ Message: "llama", }, diff --git a/lib/tbot/config/output_client_credential.go b/lib/tbot/config/output_client_credential.go index e2002b456a6b3..e4f37d2ea9d14 100644 --- a/lib/tbot/config/output_client_credential.go +++ b/lib/tbot/config/output_client_credential.go @@ -94,6 +94,16 @@ func (o *UnstableClientCredentialOutput) SSHClientConfig() (*ssh.ClientConfig, e return o.facade.SSHClientConfig() } +// Facade returns the underlying facade +func (o *UnstableClientCredentialOutput) Facade() (*identity.Facade, error) { + o.mu.Lock() + defer o.mu.Unlock() + if o.facade == nil { + return nil, trace.BadParameter("credentials not yet ready") + } + return o.facade, nil +} + // Render implements the Destination interface and is called regularly by the // bot with new credentials. Render passes these credentials down to the // underlying facade so that they can be used in TLS/SSH configs. diff --git a/lib/tbot/config/service_database_tunnel.go b/lib/tbot/config/service_database_tunnel.go new file mode 100644 index 0000000000000..4b0497f7a5857 --- /dev/null +++ b/lib/tbot/config/service_database_tunnel.go @@ -0,0 +1,82 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package config + +import ( + "net/url" + + "github.com/gravitational/trace" + "gopkg.in/yaml.v3" +) + +const DatabaseTunnelServiceType = "database-tunnel" + +// DatabaseTunnelService opens an authenticated tunnel for Database Access. +type DatabaseTunnelService struct { + // Listen is the address on which database tunnel should listen. Example: + // - "tcp://127.0.0.1:3306" + // - "tcp://0.0.0.0:3306 + Listen string `yaml:"listen"` + // Roles is the list of roles to request for the tunnel. + // If empty, it defaults to all the bot's roles. + Roles []string `yaml:"roles,omitempty"` + // Service is the service name of the Teleport database. Generally this is + // the name of the Teleport resource. This field is required for all types + // of database. + Service string `yaml:"service"` + // Database is the name of the database to proxy to. + Database string `yaml:"database"` + // Username is the database username to proxy as. + Username string `yaml:"username"` +} + +func (s *DatabaseTunnelService) Type() string { + return DatabaseTunnelServiceType +} + +func (s *DatabaseTunnelService) MarshalYAML() (interface{}, error) { + type raw DatabaseTunnelService + return withTypeHeader((*raw)(s), DatabaseTunnelServiceType) +} + +func (s *DatabaseTunnelService) UnmarshalYAML(node *yaml.Node) error { + // Alias type to remove UnmarshalYAML to avoid recursion + type raw DatabaseTunnelService + if err := node.Decode((*raw)(s)); err != nil { + return trace.Wrap(err) + } + return nil +} + +func (s *DatabaseTunnelService) CheckAndSetDefaults() error { + switch { + case s.Listen == "": + return trace.BadParameter("listen: should not be empty") + case s.Service == "": + return trace.BadParameter("service: should not be empty") + case s.Database == "": + return trace.BadParameter("database: should not be empty") + case s.Username == "": + return trace.BadParameter("username: should not be empty") + } + if _, err := url.Parse(s.Listen); err != nil { + return trace.Wrap(err, "parsing listen") + } + return nil +} diff --git a/lib/tbot/config/service_database_tunnel_test.go b/lib/tbot/config/service_database_tunnel_test.go new file mode 100644 index 0000000000000..eb6303d8f0153 --- /dev/null +++ b/lib/tbot/config/service_database_tunnel_test.go @@ -0,0 +1,108 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package config + +import "testing" + +func TestDatabaseTunnelService_YAML(t *testing.T) { + t.Parallel() + + tests := []testYAMLCase[DatabaseTunnelService]{ + { + name: "full", + in: DatabaseTunnelService{ + Listen: "tcp://0.0.0.0:3621", + Roles: []string{"role1", "role2"}, + Service: "service", + Database: "database", + Username: "username", + }, + }, + } + testYAML(t, tests) +} + +func TestDatabaseTunnelService_CheckAndSetDefaults(t *testing.T) { + t.Parallel() + + tests := []testCheckAndSetDefaultsCase[*DatabaseTunnelService]{ + { + name: "valid", + in: func() *DatabaseTunnelService { + return &DatabaseTunnelService{ + Listen: "tcp://0.0.0.0:3621", + Roles: []string{"role1", "role2"}, + Service: "service", + Database: "database", + Username: "username", + } + }, + wantErr: "", + }, + { + name: "missing listen", + in: func() *DatabaseTunnelService { + return &DatabaseTunnelService{ + Roles: []string{"role1", "role2"}, + Service: "service", + Database: "database", + Username: "username", + } + }, + wantErr: "listen: should not be empty", + }, + { + name: "missing service", + in: func() *DatabaseTunnelService { + return &DatabaseTunnelService{ + Listen: "tcp://0.0.0.0:3621", + Roles: []string{"role1", "role2"}, + Database: "database", + Username: "username", + } + }, + wantErr: "service: should not be empty", + }, + { + name: "missing database", + in: func() *DatabaseTunnelService { + return &DatabaseTunnelService{ + Listen: "tcp://0.0.0.0:3621", + Roles: []string{"role1", "role2"}, + Service: "service", + Username: "username", + } + }, + wantErr: "database: should not be empty", + }, + { + name: "missing username", + in: func() *DatabaseTunnelService { + return &DatabaseTunnelService{ + Listen: "tcp://0.0.0.0:3621", + Roles: []string{"role1", "role2"}, + Service: "service", + Database: "database", + } + }, + wantErr: "username: should not be empty", + }, + } + testCheckAndSetDefaults(t, tests) +} diff --git a/lib/tbot/config/service_example.go b/lib/tbot/config/service_example.go index f6d1172f59d9d..471b53bf5c1cd 100644 --- a/lib/tbot/config/service_example.go +++ b/lib/tbot/config/service_example.go @@ -19,10 +19,6 @@ package config import ( - "context" - "fmt" - "time" - "github.com/gravitational/trace" "gopkg.in/yaml.v3" ) @@ -36,17 +32,6 @@ type ExampleService struct { Message string `yaml:"message"` } -func (s *ExampleService) Run(ctx context.Context) error { - for { - select { - case <-ctx.Done(): - return nil - case <-time.After(time.Second * 5): - fmt.Println("Example Service prints message:", s.Message) - } - } -} - func (s *ExampleService) Type() string { return ExampleServiceType } @@ -65,6 +50,9 @@ func (s *ExampleService) UnmarshalYAML(node *yaml.Node) error { return nil } -func (s *ExampleService) String() string { - return fmt.Sprintf("%s:%s", ExampleServiceType, s.Message) +func (s *ExampleService) CheckAndSetDefaults() error { + if s.Message == "" { + return trace.BadParameter("message: should not be empty") + } + return nil } diff --git a/lib/tbot/config/testdata/TestDatabaseTunnelService_YAML/full.golden b/lib/tbot/config/testdata/TestDatabaseTunnelService_YAML/full.golden new file mode 100644 index 0000000000000..f56e8dfa7c3c5 --- /dev/null +++ b/lib/tbot/config/testdata/TestDatabaseTunnelService_YAML/full.golden @@ -0,0 +1,8 @@ +type: database-tunnel +listen: tcp://0.0.0.0:3621 +roles: + - role1 + - role2 +service: service +database: database +username: username diff --git a/lib/tbot/database_access.go b/lib/tbot/database_access.go new file mode 100644 index 0000000000000..cce260d6a4c88 --- /dev/null +++ b/lib/tbot/database_access.go @@ -0,0 +1,96 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tbot + +import ( + "context" + + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + + apiclient "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" + libdefaults "github.com/gravitational/teleport/lib/defaults" +) + +func getDatabase(ctx context.Context, clt *auth.Client, name string) (types.Database, error) { + ctx, span := tracer.Start(ctx, "getDatabase") + defer span.End() + + servers, err := apiclient.GetAllResources[types.DatabaseServer](ctx, clt, &proto.ListResourcesRequest{ + Namespace: defaults.Namespace, + ResourceType: types.KindDatabaseServer, + PredicateExpression: makeNameOrDiscoveredNamePredicate(name), + Limit: int32(defaults.DefaultChunkSize), + }) + if err != nil { + return nil, trace.Wrap(err) + } + + var databases []types.Database + for _, server := range servers { + databases = append(databases, server.GetDatabase()) + } + + databases = types.DeduplicateDatabases(databases) + db, err := chooseOneDatabase(databases, name) + return db, trace.Wrap(err) +} + +func getRouteToDatabase( + ctx context.Context, + log logrus.FieldLogger, + client *auth.Client, + service string, + username string, + database string, +) (proto.RouteToDatabase, error) { + ctx, span := tracer.Start(ctx, "getRouteToDatabase") + defer span.End() + + if service == "" { + return proto.RouteToDatabase{}, nil + } + + db, err := getDatabase(ctx, client, service) + if err != nil { + return proto.RouteToDatabase{}, trace.Wrap(err) + } + // make sure the output matches the fully resolved db name, since it may + // have been just a "discovered name". + service = db.GetName() + if db.GetProtocol() == libdefaults.ProtocolMongoDB && username == "" { + // This isn't strictly a runtime error so killing the process seems + // wrong. We'll just loudly warn about it. + log.Errorf("Database `username` field for %q is unset but is required for MongoDB databases.", service) + } else if db.GetProtocol() == libdefaults.ProtocolRedis && username == "" { + // Per tsh's lead, fall back to the default username. + username = libdefaults.DefaultRedisUsername + } + + return proto.RouteToDatabase{ + ServiceName: service, + Protocol: db.GetProtocol(), + Database: database, + Username: username, + }, nil +} diff --git a/lib/tbot/service_database_tunnel.go b/lib/tbot/service_database_tunnel.go new file mode 100644 index 0000000000000..7cd1c7ce67929 --- /dev/null +++ b/lib/tbot/service_database_tunnel.go @@ -0,0 +1,297 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tbot + +import ( + "context" + "crypto/tls" + "fmt" + "net" + "net/url" + + "github.com/gravitational/trace" + "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/reversetunnelclient" + "github.com/gravitational/teleport/lib/srv/alpnproxy" + "github.com/gravitational/teleport/lib/srv/alpnproxy/common" + "github.com/gravitational/teleport/lib/tbot/config" + "github.com/gravitational/teleport/lib/tbot/identity" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" +) + +var _ alpnproxy.LocalProxyMiddleware = (*alpnProxyMiddleware)(nil) + +type alpnProxyMiddleware struct { + onNewConnection func(ctx context.Context, lp *alpnproxy.LocalProxy, conn net.Conn) error + onStart func(ctx context.Context, lp *alpnproxy.LocalProxy) error +} + +func (a alpnProxyMiddleware) OnNewConnection(ctx context.Context, lp *alpnproxy.LocalProxy, conn net.Conn) error { + if a.onNewConnection != nil { + return a.onNewConnection(ctx, lp, conn) + } + return nil +} + +func (a alpnProxyMiddleware) OnStart(ctx context.Context, lp *alpnproxy.LocalProxy) error { + if a.onStart != nil { + return a.onStart(ctx, lp) + } + return nil +} + +// DatabaseTunnelService is a service that listens on a local port and forwards +// connections to a remote database service. It is an authenticating tunnel and +// will automatically issue and renew certificates as needed. +type DatabaseTunnelService struct { + botCfg *config.BotConfig + cfg *config.DatabaseTunnelService + proxyPingCache *proxyPingCache + log logrus.FieldLogger + resolver reversetunnelclient.Resolver + botClient *auth.Client + getBotIdentity getBotIdentityFn +} + +// buildLocalProxyConfig initializes the service, fetching any initial information and setting +// up the localproxy. +func (s *DatabaseTunnelService) buildLocalProxyConfig(ctx context.Context) (lpCfg alpnproxy.LocalProxyConfig, err error) { + ctx, span := tracer.Start(ctx, "DatabaseTunnelService/buildLocalProxyConfig") + defer span.End() + + // Determine the roles to use for the impersonated db access user. We fall + // back to all the roles the bot has if none are configured. + roles := s.cfg.Roles + if len(roles) == 0 { + roles, err = fetchDefaultRoles(ctx, s.botClient, s.getBotIdentity()) + if err != nil { + return alpnproxy.LocalProxyConfig{}, trace.Wrap(err, "fetching default roles") + } + s.log.WithField("roles", roles).Debug("No roles configured, using all roles available.") + } + + proxyPing, err := s.proxyPingCache.ping(ctx) + if err != nil { + return alpnproxy.LocalProxyConfig{}, trace.Wrap(err, "pinging proxy") + } + proxyAddr := proxyPing.Proxy.SSH.PublicAddr + + // Fetch information about the database and then issue the initial + // certificate. We issue the initial certificate to allow us to fail faster. + // We cache the routeToDatabase as these will not change during the lifetime + // of the service and this reduces the time needed to issue a new + // certificate. + s.log.Debug("Determining route to database.") + routeToDatabase, err := s.getRouteToDatabaseWithImpersonation(ctx, roles) + if err != nil { + return alpnproxy.LocalProxyConfig{}, trace.Wrap(err) + } + s.log.WithFields(logrus.Fields{ + "serviceName": routeToDatabase.ServiceName, + "protocol": routeToDatabase.Protocol, + "database": routeToDatabase.Database, + "username": routeToDatabase.Username, + }).Debug("Identified route to database.") + + s.log.Debug("Issuing initial certificate for local proxy.") + dbCert, err := s.issueCert(ctx, routeToDatabase, roles) + if err != nil { + return alpnproxy.LocalProxyConfig{}, trace.Wrap(err) + } + s.log.Debug("Issued initial certificate for local proxy.") + + middleware := alpnProxyMiddleware{ + onNewConnection: func(ctx context.Context, lp *alpnproxy.LocalProxy, conn net.Conn) error { + ctx, span := tracer.Start(ctx, "DatabaseTunnelService/OnNewConnection") + defer span.End() + + // Check if the certificate needs reissuing, if so, reissue. + if err := lp.CheckDBCerts(tlsca.RouteToDatabase{ + ServiceName: routeToDatabase.ServiceName, + Protocol: routeToDatabase.Protocol, + Database: routeToDatabase.Database, + Username: routeToDatabase.Username, + }); err != nil { + s.log.WithField("reason", err.Error()).Info("Certificate for tunnel needs reissuing.") + cert, err := s.issueCert(ctx, routeToDatabase, roles) + if err != nil { + return trace.Wrap(err, "issuing cert") + } + lp.SetCerts([]tls.Certificate{*cert}) + } + return nil + }, + } + + alpnProtocol, err := common.ToALPNProtocol(routeToDatabase.Protocol) + if err != nil { + return alpnproxy.LocalProxyConfig{}, trace.Wrap(err) + + } + lpConfig := alpnproxy.LocalProxyConfig{ + Middleware: middleware, + + RemoteProxyAddr: proxyAddr, + ParentContext: ctx, + Protocols: []common.Protocol{alpnProtocol}, + Certs: []tls.Certificate{*dbCert}, + InsecureSkipVerify: s.botCfg.Insecure, + } + if client.IsALPNConnUpgradeRequired( + ctx, + proxyAddr, + s.botCfg.Insecure, + ) { + lpConfig.ALPNConnUpgradeRequired = true + // If ALPN Conn Upgrade will be used, we need to set the cluster CAs + // to validate the Proxy's auth issued host cert. + lpConfig.RootCAs = s.getBotIdentity().TLSCAPool + } + + return lpConfig, nil +} + +func (s *DatabaseTunnelService) Run(ctx context.Context) error { + ctx, span := tracer.Start(ctx, "DatabaseTunnelService/Run") + defer span.End() + + listenUrl, err := url.Parse(s.cfg.Listen) + if err != nil { + return trace.Wrap(err, "parsing listen url") + } + + s.log.WithField("address", listenUrl.String()).Debug("Opening listener for database tunnel.") + l, err := net.Listen("tcp", listenUrl.Host) + if err != nil { + return trace.Wrap(err, "opening listener") + } + defer func() { + if err := l.Close(); err != nil && !utils.IsUseOfClosedNetworkError(err) { + s.log.WithError(err).Error("Failed to close listener") + } + }() + + lpCfg, err := s.buildLocalProxyConfig(ctx) + if err != nil { + return trace.Wrap(err, "building local proxy config") + } + lpCfg.Listener = l + + lp, err := alpnproxy.NewLocalProxy(lpCfg) + if err != nil { + return trace.Wrap(err, "creating local proxy") + } + defer func() { + if err := lp.Close(); err != nil { + s.log.WithError(err).Error("Failed to close local proxy") + } + }() + // Closed further down. + + // lp.Start will block and continues to block until lp.Close() is called. + // Despite taking a context, it will not exit until the first connection is + // made after the context is canceled. + var errCh = make(chan error, 1) + go func() { + errCh <- lp.Start(ctx) + }() + s.log.WithField("address", l.Addr().String()).Info("Listening for connections.") + + select { + case <-ctx.Done(): + return nil + case err := <-errCh: + return trace.Wrap(err, "local proxy failed") + } +} + +// getRouteToDatabaseWithImpersonation fetches the route to the database with +// impersonation of roles. This ensures that the user's selected roles actually +// grant access to the database. +func (s *DatabaseTunnelService) getRouteToDatabaseWithImpersonation(ctx context.Context, roles []string) (proto.RouteToDatabase, error) { + ctx, span := tracer.Start(ctx, "DatabaseTunnelService/getRouteToDatabaseWithImpersonation") + defer span.End() + + impersonatedIdentity, err := generateIdentity( + ctx, + s.botClient, + s.getBotIdentity(), + roles, + s.botCfg.CertificateTTL, + nil, + ) + if err != nil { + return proto.RouteToDatabase{}, trace.Wrap(err) + } + + impersonatedClient, err := clientForFacade( + ctx, + s.log, + s.botCfg, + identity.NewFacade(s.botCfg.FIPS, s.botCfg.Insecure, impersonatedIdentity), + s.resolver, + ) + if err != nil { + return proto.RouteToDatabase{}, trace.Wrap(err) + } + defer func() { + if err := impersonatedClient.Close(); err != nil { + s.log.WithError(err).Error("Failed to close impersonated client.") + } + }() + + return getRouteToDatabase(ctx, s.log, impersonatedClient, s.cfg.Service, s.cfg.Username, s.cfg.Database) +} + +func (s *DatabaseTunnelService) issueCert( + ctx context.Context, + route proto.RouteToDatabase, + roles []string, +) (*tls.Certificate, error) { + ctx, span := tracer.Start(ctx, "DatabaseTunnelService/issueCert") + defer span.End() + + s.log.Debug("Requesting issuance of certificate for tunnel proxy.") + ident, err := generateIdentity( + ctx, + s.botClient, + s.getBotIdentity(), + roles, + s.botCfg.CertificateTTL, + func(req *proto.UserCertsRequest) { + req.RouteToDatabase = route + }) + if err != nil { + return nil, trace.Wrap(err) + } + s.log.Info("Certificate issued for tunnel proxy.") + + return ident.TLSCert, nil +} + +// String returns a human-readable string that can uniquely identify the +// service. +func (s *DatabaseTunnelService) String() string { + return fmt.Sprintf("%s:%s:%s", config.DatabaseTunnelServiceType, s.cfg.Listen, s.cfg.Service) +} diff --git a/lib/tbot/service_example.go b/lib/tbot/service_example.go new file mode 100644 index 0000000000000..c550c52d9f24f --- /dev/null +++ b/lib/tbot/service_example.go @@ -0,0 +1,50 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package tbot + +import ( + "context" + "fmt" + "time" + + "github.com/gravitational/teleport/lib/tbot/config" +) + +// ExampleService is a temporary example service for testing purposes. It is +// not intended to be used and exists to demonstrate how a user configurable +// service integrates with the tbot service manager. +type ExampleService struct { + cfg *config.ExampleService + Message string `yaml:"message"` +} + +func (s *ExampleService) Run(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return nil + case <-time.After(time.Second * 5): + fmt.Println("Example Service prints message:", s.Message) + } + } +} + +func (s *ExampleService) String() string { + return fmt.Sprintf("%s:%s", config.ExampleServiceType, s.Message) +} diff --git a/lib/tbot/service_outputs.go b/lib/tbot/service_outputs.go index ec68b3a7ddc33..d7e7775610869 100644 --- a/lib/tbot/service_outputs.go +++ b/lib/tbot/service_outputs.go @@ -38,7 +38,6 @@ import ( "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/native" - libdefaults "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/tbot/config" @@ -56,6 +55,8 @@ const renewalRetryLimit = 5 type outputsService struct { log logrus.FieldLogger reloadBroadcaster *channelBroadcaster + proxyPingCache *proxyPingCache + authPingCache *authPingCache botClient *auth.Client getBotIdentity getBotIdentityFn cfg *config.BotConfig @@ -86,16 +87,16 @@ func (s *outputsService) renewOutputs( // create a cache shared across outputs so they don't hammer the auth // server with similar requests drc := &outputRenewalCache{ - client: s.botClient, - cfg: s.cfg, + proxyPingCache: s.proxyPingCache, + authPingCache: s.authPingCache, + client: s.botClient, + cfg: s.cfg, } // Determine the default role list based on the bot role. The role's // name should match the certificate's Key ID (user and role names // should all match bot-$name) - botIdentity := s.getBotIdentity() - botResourceName := botIdentity.X509Cert.Subject.CommonName - defaultRoles, err := fetchDefaultRoles(ctx, s.botClient, botResourceName) + defaultRoles, err := fetchDefaultRoles(ctx, s.botClient, s.getBotIdentity()) if err != nil { s.log.WithError(err).Warnf("Unable to determine default roles, no roles will be requested if unspecified") defaultRoles = []string{} @@ -124,7 +125,7 @@ func (s *outputsService) renewOutputs( } impersonatedIdentity, impersonatedClient, err := s.generateImpersonatedIdentity( - ctx, s.botClient, botIdentity, output, defaultRoles, + ctx, s.botClient, s.getBotIdentity(), output, defaultRoles, ) if err != nil { return trace.Wrap(err, "generating impersonated certs for output: %s", output) @@ -286,15 +287,15 @@ type identityConfigurator = func(req *proto.UserCertsRequest) // impersonated identity that already has the relevant permissions, much like // `tsh (app|db|kube) login` is already used to generate an additional set of // certs. -func (s *outputsService) generateIdentity( +func generateIdentity( ctx context.Context, client *auth.Client, currentIdentity *identity.Identity, - output config.Output, - defaultRoles []string, + roles []string, + ttl time.Duration, configurator identityConfigurator, ) (*identity.Identity, error) { - ctx, span := tracer.Start(ctx, "outputsService/generateIdentity") + ctx, span := tracer.Start(ctx, "generateIdentity") defer span.End() // TODO: enforce expiration > renewal period (by what margin?) @@ -309,19 +310,11 @@ func (s *outputsService) generateIdentity( return nil, trace.Wrap(err) } - var roleRequests []string - if roles := output.GetRoles(); len(roles) > 0 { - roleRequests = roles - } else { - s.log.Debugf("Output specified no roles, defaults will be requested: %v", defaultRoles) - roleRequests = defaultRoles - } - req := proto.UserCertsRequest{ PublicKey: publicKey, Username: currentIdentity.X509Cert.Subject.CommonName, - Expires: time.Now().Add(s.cfg.CertificateTTL), - RoleRequests: roleRequests, + Expires: time.Now().Add(ttl), + RoleRequests: roles, RouteToCluster: currentIdentity.ClusterName, // Make sure to specify this is an impersonated cert request. If unset, @@ -380,64 +373,6 @@ func (s *outputsService) generateIdentity( return newIdentity, nil } -func getDatabase(ctx context.Context, clt *auth.Client, name string) (types.Database, error) { - ctx, span := tracer.Start(ctx, "getDatabase") - defer span.End() - - servers, err := apiclient.GetAllResources[types.DatabaseServer](ctx, clt, &proto.ListResourcesRequest{ - Namespace: defaults.Namespace, - ResourceType: types.KindDatabaseServer, - PredicateExpression: makeNameOrDiscoveredNamePredicate(name), - Limit: int32(defaults.DefaultChunkSize), - }) - if err != nil { - return nil, trace.Wrap(err) - } - - var databases []types.Database - for _, server := range servers { - databases = append(databases, server.GetDatabase()) - } - - databases = types.DeduplicateDatabases(databases) - db, err := chooseOneDatabase(databases, name) - return db, trace.Wrap(err) -} - -func (s *outputsService) getRouteToDatabase(ctx context.Context, client *auth.Client, output *config.DatabaseOutput) (proto.RouteToDatabase, error) { - ctx, span := tracer.Start(ctx, "outputsService/getRouteToDatabase") - defer span.End() - - if output.Service == "" { - return proto.RouteToDatabase{}, nil - } - - db, err := getDatabase(ctx, client, output.Service) - if err != nil { - return proto.RouteToDatabase{}, trace.Wrap(err) - } - // make sure the output matches the fully resolved db name, since it may - // have been just a "discovered name". - output.Service = db.GetName() - - username := output.Username - if db.GetProtocol() == libdefaults.ProtocolMongoDB && username == "" { - // This isn't strictly a runtime error so killing the process seems - // wrong. We'll just loudly warn about it. - s.log.Errorf("Database `username` field for %q is unset but is required for MongoDB databases.", output.Service) - } else if db.GetProtocol() == libdefaults.ProtocolRedis && username == "" { - // Per tsh's lead, fall back to the default username. - username = libdefaults.DefaultRedisUsername - } - - return proto.RouteToDatabase{ - ServiceName: output.Service, - Protocol: db.GetProtocol(), - Database: output.Database, - Username: username, - }, nil -} - func getKubeCluster(ctx context.Context, clt *auth.Client, name string) (types.KubeCluster, error) { ctx, span := tracer.Start(ctx, "getKubeCluster") defer span.End() @@ -534,8 +469,14 @@ func (s *outputsService) generateImpersonatedIdentity( ctx, span := tracer.Start(ctx, "outputsService/generateImpersonatedIdentity") defer span.End() - impersonatedIdentity, err = s.generateIdentity( - ctx, botClient, botIdentity, output, defaultRoles, nil, + roles := output.GetRoles() + if len(roles) == 0 { + s.log.Debugf("Output specified no roles, defaults will be requested: %v", defaultRoles) + roles = defaultRoles + } + + impersonatedIdentity, err = generateIdentity( + ctx, botClient, botIdentity, roles, s.cfg.CertificateTTL, nil, ) if err != nil { return nil, nil, trace.Wrap(err) @@ -564,16 +505,29 @@ func (s *outputsService) generateImpersonatedIdentity( return impersonatedIdentity, impersonatedClient, nil } - routedIdentity, err := s.generateIdentity(ctx, botClient, impersonatedIdentity, output, defaultRoles, func(req *proto.UserCertsRequest) { - req.RouteToCluster = output.Cluster - }, + routedIdentity, err := generateIdentity( + ctx, + botClient, + impersonatedIdentity, + roles, + s.cfg.CertificateTTL, + func(req *proto.UserCertsRequest) { + req.RouteToCluster = output.Cluster + }, ) if err != nil { return nil, nil, trace.Wrap(err) } return routedIdentity, impersonatedClient, nil case *config.DatabaseOutput: - route, err := s.getRouteToDatabase(ctx, impersonatedClient, output) + route, err := getRouteToDatabase( + ctx, + s.log, + impersonatedClient, + output.Service, + output.Username, + output.Database, + ) if err != nil { return nil, nil, trace.Wrap(err) } @@ -582,9 +536,16 @@ func (s *outputsService) generateImpersonatedIdentity( // so we'll request the database access identity using the main bot // identity (having gathered the necessary info for RouteToDatabase // using the correct impersonated unroutedIdentity.) - routedIdentity, err := s.generateIdentity(ctx, botClient, impersonatedIdentity, output, defaultRoles, func(req *proto.UserCertsRequest) { - req.RouteToDatabase = route - }) + routedIdentity, err := generateIdentity( + ctx, + botClient, + impersonatedIdentity, + roles, + s.cfg.CertificateTTL, + func(req *proto.UserCertsRequest) { + req.RouteToDatabase = route + }, + ) if err != nil { return nil, nil, trace.Wrap(err) } @@ -603,9 +564,16 @@ func (s *outputsService) generateImpersonatedIdentity( // Note: the Teleport server does attempt to verify k8s cluster names // and will fail to generate certs if the cluster doesn't exist or is // offline. - routedIdentity, err := s.generateIdentity(ctx, botClient, impersonatedIdentity, output, defaultRoles, func(req *proto.UserCertsRequest) { - req.KubernetesCluster = output.KubernetesCluster - }) + routedIdentity, err := generateIdentity( + ctx, + botClient, + impersonatedIdentity, + roles, + s.cfg.CertificateTTL, + func(req *proto.UserCertsRequest) { + req.KubernetesCluster = output.KubernetesCluster + }, + ) if err != nil { return nil, nil, trace.Wrap(err) } @@ -619,9 +587,16 @@ func (s *outputsService) generateImpersonatedIdentity( return nil, nil, trace.Wrap(err) } - routedIdentity, err := s.generateIdentity(ctx, botClient, impersonatedIdentity, output, defaultRoles, func(req *proto.UserCertsRequest) { - req.RouteToApp = routeToApp - }) + routedIdentity, err := generateIdentity( + ctx, + botClient, + impersonatedIdentity, + roles, + s.cfg.CertificateTTL, + func(req *proto.UserCertsRequest) { + req.RouteToApp = routeToApp + }, + ) if err != nil { return nil, nil, trace.Wrap(err) } @@ -640,8 +615,8 @@ func (s *outputsService) generateImpersonatedIdentity( // fetchDefaultRoles requests the bot's own role from the auth server and // extracts its full list of allowed roles. -func fetchDefaultRoles(ctx context.Context, roleGetter services.RoleGetter, botRole string) ([]string, error) { - role, err := roleGetter.GetRole(ctx, botRole) +func fetchDefaultRoles(ctx context.Context, roleGetter services.RoleGetter, identity *identity.Identity) ([]string, error) { + role, err := roleGetter.GetRole(ctx, identity.X509Cert.Subject.CommonName) if err != nil { return nil, trace.Wrap(err) } @@ -655,14 +630,14 @@ func fetchDefaultRoles(ctx context.Context, roleGetter services.RoleGetter, botR // requests for the same information. This is shared between all of the // outputs. type outputRenewalCache struct { - client *auth.Client + client *auth.Client + cfg *config.BotConfig + proxyPingCache *proxyPingCache + authPingCache *authPingCache - cfg *config.BotConfig - mu sync.Mutex + mu sync.Mutex // These are protected by getter/setters with mutex locks - _cas map[types.CertAuthType][]types.CertAuthority - _authPong *proto.PingResponse - _proxyPong *webclient.PingResponse + _cas map[types.CertAuthType][]types.CertAuthority } func (orc *outputRenewalCache) getCertAuthorities( @@ -694,71 +669,23 @@ func (orc *outputRenewalCache) GetCertAuthorities( return orc.getCertAuthorities(ctx, caType) } -func (orc *outputRenewalCache) authPing(ctx context.Context) (*proto.PingResponse, error) { - if orc._authPong != nil { - return orc._authPong, nil - } - - pong, err := orc.client.Ping(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - orc._authPong = &pong - - return &pong, nil -} - // AuthPing pings the auth server and returns the (possibly cached) response. func (orc *outputRenewalCache) AuthPing(ctx context.Context) (*proto.PingResponse, error) { - orc.mu.Lock() - defer orc.mu.Unlock() - return orc.authPing(ctx) -} - -func (orc *outputRenewalCache) proxyPing(ctx context.Context) (*webclient.PingResponse, error) { - if orc._proxyPong != nil { - return orc._proxyPong, nil - } - - // Determine the Proxy address to use. - addr, addrKind := orc.cfg.Address() - switch addrKind { - case config.AddressKindAuth: - // If the address is an auth address, ping auth to determine proxy addr. - authPong, err := orc.authPing(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - addr = authPong.ProxyPublicAddr - case config.AddressKindProxy: - // If the address is a proxy address, use it directly. - default: - return nil, trace.BadParameter("unsupported address kind: %v", addrKind) - } - - // We use find instead of Ping as it's less resource intense and we can - // ping the AuthServer directly for its configuration if necessary. - proxyPong, err := webclient.Find(&webclient.Config{ - Context: ctx, - ProxyAddr: addr, - Insecure: orc.cfg.Insecure, - }) + res, err := orc.authPingCache.ping(ctx) if err != nil { return nil, trace.Wrap(err) - } - - orc._proxyPong = proxyPong - return proxyPong, nil + } + return &res, nil } // ProxyPing returns a (possibly cached) ping response from the Teleport proxy. -// Note that it relies on the auth server being configured with a sane proxy -// public address. func (orc *outputRenewalCache) ProxyPing(ctx context.Context) (*webclient.PingResponse, error) { - orc.mu.Lock() - defer orc.mu.Unlock() - return orc.proxyPing(ctx) + res, err := orc.proxyPingCache.ping(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + return res, nil } // Config returns the bots config. diff --git a/lib/tbot/tbot.go b/lib/tbot/tbot.go index a893cdee0e695..b0a0355aebec1 100644 --- a/lib/tbot/tbot.go +++ b/lib/tbot/tbot.go @@ -29,6 +29,7 @@ import ( "google.golang.org/grpc" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/client/webclient" "github.com/gravitational/teleport/api/metadata" "github.com/gravitational/teleport/api/types" @@ -36,7 +37,6 @@ import ( "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/reversetunnelclient" - "github.com/gravitational/teleport/lib/tbot/bot" "github.com/gravitational/teleport/lib/tbot/config" "github.com/gravitational/teleport/lib/tbot/identity" "github.com/gravitational/teleport/lib/utils" @@ -46,6 +46,29 @@ var tracer = otel.Tracer("github.com/gravitational/teleport/lib/tbot") const componentTBot = "tbot" +// Service is a long-running sub-component of tbot. +type Service interface { + // String returns a human-readable name for the service that can be used + // in logging. It should identify the type of the service and any top + // level configuration that could distinguish it from a same-type service. + String() string + // Run starts the service and blocks until the service exits. It should + // return a nil error if the service exits successfully and an error + // if it is unable to proceed. It should exit gracefully if the context + // is canceled. + Run(ctx context.Context) error +} + +// OneShotService is a [Service] that offers a mode in which it runs a single +// time and then exits. This aligns with the `--oneshot` mode of tbot. +type OneShotService interface { + Service + // OneShot runs the service once and then exits. It should return a nil + // error if the service exits successfully and an error if it is unable + // to proceed. It should exit gracefully if the context is canceled. + OneShot(ctx context.Context) error +} + type Bot struct { cfg *config.BotConfig log logrus.FieldLogger @@ -123,7 +146,7 @@ func (b *Bot) Run(ctx context.Context) error { // Create an error group to manage all the services lifetimes. eg, egCtx := errgroup.WithContext(ctx) - var services []bot.Service + var services []Service // ReloadBroadcaster allows multiple entities to trigger a reload of // all services. This allows os signals and other events such as CA @@ -167,6 +190,16 @@ func (b *Bot) Run(ctx context.Context) error { }() services = append(services, b.botIdentitySvc) + authPingCache := &authPingCache{ + client: b.botIdentitySvc.GetClient(), + log: b.log, + } + proxyPingCache := &proxyPingCache{ + authPingCache: authPingCache, + botCfg: b.cfg, + log: b.log, + } + // Setup all other services if b.cfg.DiagAddr != "" { services = append(services, &diagnosticsService{ @@ -178,6 +211,8 @@ func (b *Bot) Run(ctx context.Context) error { }) } services = append(services, &outputsService{ + authPingCache: authPingCache, + proxyPingCache: proxyPingCache, getBotIdentity: b.botIdentitySvc.GetIdentity, botClient: b.botIdentitySvc.GetClient(), cfg: b.cfg, @@ -196,7 +231,30 @@ func (b *Bot) Run(ctx context.Context) error { reloadBroadcaster: reloadBroadcaster, }) // Append any services configured by the user - services = append(services, b.cfg.Services...) + for _, svcCfg := range b.cfg.Services { + // Convert the service config into the actual service type. + switch svcCfg := svcCfg.(type) { + case *config.DatabaseTunnelService: + svc := &DatabaseTunnelService{ + getBotIdentity: b.botIdentitySvc.GetIdentity, + proxyPingCache: proxyPingCache, + botClient: b.botIdentitySvc.GetClient(), + resolver: resolver, + botCfg: b.cfg, + cfg: svcCfg, + } + svc.log = b.log.WithField( + trace.Component, teleport.Component(componentTBot, "svc", svc.String()), + ) + services = append(services, svc) + case *config.ExampleService: + services = append(services, &ExampleService{ + cfg: svcCfg, + }) + default: + return trace.BadParameter("unknown service type: %T", svcCfg) + } + } b.log.Info("Initialization complete. Starting services.") // Start services @@ -205,7 +263,7 @@ func (b *Bot) Run(ctx context.Context) error { log := b.log.WithField("service", svc.String()) if b.cfg.Oneshot { - svc, ok := svc.(bot.OneShotService) + svc, ok := svc.(OneShotService) // We ignore services with no one-shot implementation if !ok { log.Debug("Service does not support oneshot mode, ignoring.") @@ -399,3 +457,78 @@ func clientForFacade( c, err := authclient.Connect(ctx, authClientConfig) return c, trace.Wrap(err) } + +type authPingCache struct { + client *auth.Client + log logrus.FieldLogger + + mu sync.RWMutex + cachedValue *proto.PingResponse +} + +func (a *authPingCache) ping(ctx context.Context) (proto.PingResponse, error) { + a.mu.Lock() + defer a.mu.Unlock() + if a.cachedValue != nil { + return *a.cachedValue, nil + } + + a.log.Debug("Pinging auth server.") + res, err := a.client.Ping(ctx) + if err != nil { + a.log.WithError(err).Error("Failed to ping auth server.") + return proto.PingResponse{}, trace.Wrap(err) + } + a.cachedValue = &res + a.log.WithField("pong", res).Debug("Successfully pinged auth server.") + + return *a.cachedValue, nil +} + +type proxyPingCache struct { + authPingCache *authPingCache + botCfg *config.BotConfig + log logrus.FieldLogger + + mu sync.RWMutex + cachedValue *webclient.PingResponse +} + +func (p *proxyPingCache) ping(ctx context.Context) (*webclient.PingResponse, error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.cachedValue != nil { + return p.cachedValue, nil + } + + // Determine the Proxy address to use. + addr, addrKind := p.botCfg.Address() + switch addrKind { + case config.AddressKindAuth: + // If the address is an auth address, ping auth to determine proxy addr. + authPong, err := p.authPingCache.ping(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + addr = authPong.ProxyPublicAddr + case config.AddressKindProxy: + // If the address is a proxy address, use it directly. + default: + return nil, trace.BadParameter("unsupported address kind: %v", addrKind) + } + + p.log.WithField("addr", addr).Debug("Pinging proxy.") + res, err := webclient.Find(&webclient.Config{ + Context: ctx, + ProxyAddr: addr, + Insecure: p.botCfg.Insecure, + }) + if err != nil { + p.log.WithError(err).Error("Failed to ping proxy.") + return nil, trace.Wrap(err) + } + p.log.WithField("pong", res).Debug("Successfully pinged proxy.") + p.cachedValue = res + + return p.cachedValue, nil +} diff --git a/lib/tbot/tbot_test.go b/lib/tbot/tbot_test.go index 47da5e5fc5da3..34ec59a8b123d 100644 --- a/lib/tbot/tbot_test.go +++ b/lib/tbot/tbot_test.go @@ -21,16 +21,24 @@ import ( "context" "crypto/rand" "fmt" + "net" "os" "testing" + "time" + "github.com/jackc/pgconn" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/sshutils" + "github.com/gravitational/teleport/integration/helpers" "github.com/gravitational/teleport/lib/auth/native" + "github.com/gravitational/teleport/lib/service/servicecfg" + "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/srv/db/postgres" apisshutils "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tbot/bot" "github.com/gravitational/teleport/lib/tbot/botfs" @@ -599,3 +607,98 @@ func newMockDiscoveredKubeCluster(t *testing.T, name, discoveredName string) *ty require.NoError(t, err) return kubeCluster } + +func TestBotDatabaseTunnel(t *testing.T) { + t.Parallel() + ctx := context.Background() + log := utils.NewLoggerForTests() + + // Make a new auth server. + fc, fds := testhelpers.DefaultConfig(t) + process := testhelpers.MakeAndRunTestAuthServer(t, log, fc, fds) + rootClient := testhelpers.MakeDefaultAuthClient(t, log, fc) + + // Make fake postgres server and add a database access instance to expose + // it. + pts, err := postgres.NewTestServer(common.TestServerConfig{ + AuthClient: rootClient, + Users: []string{"llama"}, + }) + require.NoError(t, err) + go func() { + t.Logf("Postgres Fake server running at %s port", pts.Port()) + require.NoError(t, pts.Serve()) + }() + t.Cleanup(func() { + pts.Close() + }) + proxyAddr, err := process.ProxyWebAddr() + require.NoError(t, err) + helpers.MakeTestDatabaseServer(t, *proxyAddr, testhelpers.AgentJoinToken, nil, servicecfg.Database{ + Name: "test-database", + URI: net.JoinHostPort("localhost", pts.Port()), + Protocol: "postgres", + }) + + // Create role that allows the bot to access the database. + role, err := types.NewRole("database-access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + DatabaseLabels: types.Labels{ + "*": apiutils.Strings{"*"}, + }, + DatabaseNames: []string{"mydb"}, + DatabaseUsers: []string{"llama"}, + }, + }) + require.NoError(t, err) + err = rootClient.UpsertRole(ctx, role) + require.NoError(t, err) + + // Prepare the bot config + botParams := testhelpers.MakeBot(t, rootClient, "test", role.GetName()) + botConfig := testhelpers.DefaultBotConfig( + t, fc, botParams, []config.Output{}, + testhelpers.DefaultBotConfigOpts{ + UseAuthServer: true, + // Insecure required as the db tunnel will connect to proxies + // self-signed. + Insecure: true, + ServiceConfigs: []config.ServiceConfig{ + &config.DatabaseTunnelService{ + // TODO: Perhaps allow FD or listener to be injected + Listen: "tcp://127.0.0.1:39933", + Service: "test-database", + Database: "mydb", + Username: "llama", + }, + }, + }, + ) + botConfig.Oneshot = false + b := New(botConfig, log) + + // Spin up goroutine for bot to run in + botCtx, cancelBot := context.WithCancel(ctx) + botCh := make(chan error, 1) + go func() { + botCh <- b.Run(botCtx) + }() + + // We can't predict exactly when the tunnel will be ready so we use + // EventuallyWithT to retry. + require.EventuallyWithT(t, func(t *assert.CollectT) { + conn, err := pgconn.Connect(ctx, "postgres://127.0.0.1:39933/mydb?user=llama") + if !assert.NoError(t, err) { + return + } + defer func() { + conn.Close(ctx) + }() + _, err = conn.Exec(ctx, "SELECT 1;").ReadAll() + assert.NoError(t, err) + }, 5*time.Second, 100*time.Millisecond) + + // Shut down bot and make sure it exits cleanly. + cancelBot() + require.NoError(t, <-botCh) +} diff --git a/lib/tbot/testhelpers/srv.go b/lib/tbot/testhelpers/srv.go index 7c8015f5e2a37..b1f15390bbac5 100644 --- a/lib/tbot/testhelpers/srv.go +++ b/lib/tbot/testhelpers/srv.go @@ -43,8 +43,12 @@ type DefaultBotConfigOpts struct { // Makes the bot accept an Insecure auth or proxy server Insecure bool + + ServiceConfigs botconfig.ServiceConfigs } +const AgentJoinToken = "i-am-a-join-token" + // DefaultConfig returns a FileConfig to be used in tests, with random listen // addresses that are tied to the listeners returned in the FileDescriptor // slice, which should be passed as exported file descriptors to NewTeleport; @@ -73,6 +77,9 @@ func DefaultConfig(t *testing.T) (*config.FileConfig, []*servicecfg.FileDescript EnabledFlag: "true", ListenAddress: testenv.NewTCPListener(t, service.ListenerAuth, &fds), }, + StaticTokens: config.StaticTokens{ + config.StaticToken("db:" + AgentJoinToken), + }, }, } @@ -176,7 +183,11 @@ func MakeBot(t *testing.T, client *auth.Client, name string, roles ...string) *p // - Uses a memory storage destination // - Does not verify Proxy WebAPI certificates func DefaultBotConfig( - t *testing.T, fc *config.FileConfig, botParams *proto.CreateBotResponse, outputs []botconfig.Output, opts DefaultBotConfigOpts, + t *testing.T, + fc *config.FileConfig, + botParams *proto.CreateBotResponse, + outputs []botconfig.Output, + opts DefaultBotConfigOpts, ) *botconfig.BotConfig { t.Helper() @@ -202,6 +213,7 @@ func DefaultBotConfig( // Set Insecure so the bot will trust the Proxy's webapi default signed // certs. Insecure: opts.Insecure, + Services: opts.ServiceConfigs, } cfg.Onboarding.SetToken(botParams.TokenID)