diff --git a/lib/tbot/bot/service.go b/lib/tbot/bot/service.go
deleted file mode 100644
index d2e8ae5b0973b..0000000000000
--- a/lib/tbot/bot/service.go
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Teleport
- * Copyright (C) 2023 Gravitational, Inc.
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU Affero General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU Affero General Public License for more details.
- *
- * You should have received a copy of the GNU Affero General Public License
- * along with this program. If not, see .
- */
-
-package bot
-
-import "context"
-
-// Service is a long-running sub-component of tbot.
-type Service interface {
- // String returns a human-readable name for the service that can be used
- // in logging. It should identify the type of the service and any top
- // level configuration that could distinguish it from a same-type service.
- String() string
- // Run starts the service and blocks until the service exits. It should
- // return a nil error if the service exits successfully and an error
- // if it is unable to proceed. It should exit gracefully if the context
- // is canceled.
- Run(ctx context.Context) error
-}
-
-// OneShotService is a [Service] that offers a mode in which it runs a single
-// time and then exits. This aligns with the `--oneshot` mode of tbot.
-type OneShotService interface {
- Service
- // OneShot runs the service once and then exits. It should return a nil
- // error if the service exits successfully and an error if it is unable
- // to proceed. It should exit gracefully if the context is canceled.
- OneShot(ctx context.Context) error
-}
diff --git a/lib/tbot/config/config.go b/lib/tbot/config/config.go
index e27ccd8eade48..37800201c4206 100644
--- a/lib/tbot/config/config.go
+++ b/lib/tbot/config/config.go
@@ -270,7 +270,7 @@ type BotConfig struct {
Onboarding OnboardingConfig `yaml:"onboarding,omitempty"`
Storage *StorageConfig `yaml:"storage,omitempty"`
Outputs Outputs `yaml:"outputs,omitempty"`
- Services Services `yaml:"services,omitempty"`
+ Services ServiceConfigs `yaml:"services,omitempty"`
Debug bool `yaml:"debug"`
AuthServer string `yaml:"auth_server,omitempty"`
@@ -374,6 +374,13 @@ func (conf *BotConfig) CheckAndSetDefaults() error {
}
}
+ // Validate configured services
+ for i, service := range conf.Services {
+ if err := service.CheckAndSetDefaults(); err != nil {
+ return trace.Wrap(err, "validating service[%d]", i)
+ }
+ }
+
if conf.CertificateTTL == 0 {
conf.CertificateTTL = DefaultCertificateTTL
}
@@ -418,11 +425,17 @@ func (conf *BotConfig) CheckAndSetDefaults() error {
return nil
}
-// Services assists polymorphic unmarshaling of a slice of Services.
-type Services []bot.Service
+// ServiceConfig is an interface over the various service configurations.
+type ServiceConfig interface {
+ Type() string
+ CheckAndSetDefaults() error
+}
+
+// ServiceConfigs assists polymorphic unmarshaling of a slice of ServiceConfigs.
+type ServiceConfigs []ServiceConfig
-func (o *Services) UnmarshalYAML(node *yaml.Node) error {
- var out []bot.Service
+func (o *ServiceConfigs) UnmarshalYAML(node *yaml.Node) error {
+ var out []ServiceConfig
for _, node := range node.Content {
header := struct {
Type string `yaml:"type"`
@@ -438,6 +451,12 @@ func (o *Services) UnmarshalYAML(node *yaml.Node) error {
return trace.Wrap(err)
}
out = append(out, v)
+ case DatabaseTunnelServiceType:
+ v := &DatabaseTunnelService{}
+ if err := node.Decode(v); err != nil {
+ return trace.Wrap(err)
+ }
+ out = append(out, v)
default:
return trace.BadParameter("unrecognized service type (%s)", header.Type)
}
diff --git a/lib/tbot/config/config_test.go b/lib/tbot/config/config_test.go
index e784a470bde0e..f328819e813c3 100644
--- a/lib/tbot/config/config_test.go
+++ b/lib/tbot/config/config_test.go
@@ -248,7 +248,7 @@ func TestBotConfig_YAML(t *testing.T) {
},
},
},
- Services: []bot.Service{
+ Services: []ServiceConfig{
&ExampleService{
Message: "llama",
},
diff --git a/lib/tbot/config/output_client_credential.go b/lib/tbot/config/output_client_credential.go
index e2002b456a6b3..e4f37d2ea9d14 100644
--- a/lib/tbot/config/output_client_credential.go
+++ b/lib/tbot/config/output_client_credential.go
@@ -94,6 +94,16 @@ func (o *UnstableClientCredentialOutput) SSHClientConfig() (*ssh.ClientConfig, e
return o.facade.SSHClientConfig()
}
+// Facade returns the underlying facade
+func (o *UnstableClientCredentialOutput) Facade() (*identity.Facade, error) {
+ o.mu.Lock()
+ defer o.mu.Unlock()
+ if o.facade == nil {
+ return nil, trace.BadParameter("credentials not yet ready")
+ }
+ return o.facade, nil
+}
+
// Render implements the Destination interface and is called regularly by the
// bot with new credentials. Render passes these credentials down to the
// underlying facade so that they can be used in TLS/SSH configs.
diff --git a/lib/tbot/config/service_database_tunnel.go b/lib/tbot/config/service_database_tunnel.go
new file mode 100644
index 0000000000000..4b0497f7a5857
--- /dev/null
+++ b/lib/tbot/config/service_database_tunnel.go
@@ -0,0 +1,82 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package config
+
+import (
+ "net/url"
+
+ "github.com/gravitational/trace"
+ "gopkg.in/yaml.v3"
+)
+
+const DatabaseTunnelServiceType = "database-tunnel"
+
+// DatabaseTunnelService opens an authenticated tunnel for Database Access.
+type DatabaseTunnelService struct {
+ // Listen is the address on which database tunnel should listen. Example:
+ // - "tcp://127.0.0.1:3306"
+ // - "tcp://0.0.0.0:3306
+ Listen string `yaml:"listen"`
+ // Roles is the list of roles to request for the tunnel.
+ // If empty, it defaults to all the bot's roles.
+ Roles []string `yaml:"roles,omitempty"`
+ // Service is the service name of the Teleport database. Generally this is
+ // the name of the Teleport resource. This field is required for all types
+ // of database.
+ Service string `yaml:"service"`
+ // Database is the name of the database to proxy to.
+ Database string `yaml:"database"`
+ // Username is the database username to proxy as.
+ Username string `yaml:"username"`
+}
+
+func (s *DatabaseTunnelService) Type() string {
+ return DatabaseTunnelServiceType
+}
+
+func (s *DatabaseTunnelService) MarshalYAML() (interface{}, error) {
+ type raw DatabaseTunnelService
+ return withTypeHeader((*raw)(s), DatabaseTunnelServiceType)
+}
+
+func (s *DatabaseTunnelService) UnmarshalYAML(node *yaml.Node) error {
+ // Alias type to remove UnmarshalYAML to avoid recursion
+ type raw DatabaseTunnelService
+ if err := node.Decode((*raw)(s)); err != nil {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+func (s *DatabaseTunnelService) CheckAndSetDefaults() error {
+ switch {
+ case s.Listen == "":
+ return trace.BadParameter("listen: should not be empty")
+ case s.Service == "":
+ return trace.BadParameter("service: should not be empty")
+ case s.Database == "":
+ return trace.BadParameter("database: should not be empty")
+ case s.Username == "":
+ return trace.BadParameter("username: should not be empty")
+ }
+ if _, err := url.Parse(s.Listen); err != nil {
+ return trace.Wrap(err, "parsing listen")
+ }
+ return nil
+}
diff --git a/lib/tbot/config/service_database_tunnel_test.go b/lib/tbot/config/service_database_tunnel_test.go
new file mode 100644
index 0000000000000..eb6303d8f0153
--- /dev/null
+++ b/lib/tbot/config/service_database_tunnel_test.go
@@ -0,0 +1,108 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package config
+
+import "testing"
+
+func TestDatabaseTunnelService_YAML(t *testing.T) {
+ t.Parallel()
+
+ tests := []testYAMLCase[DatabaseTunnelService]{
+ {
+ name: "full",
+ in: DatabaseTunnelService{
+ Listen: "tcp://0.0.0.0:3621",
+ Roles: []string{"role1", "role2"},
+ Service: "service",
+ Database: "database",
+ Username: "username",
+ },
+ },
+ }
+ testYAML(t, tests)
+}
+
+func TestDatabaseTunnelService_CheckAndSetDefaults(t *testing.T) {
+ t.Parallel()
+
+ tests := []testCheckAndSetDefaultsCase[*DatabaseTunnelService]{
+ {
+ name: "valid",
+ in: func() *DatabaseTunnelService {
+ return &DatabaseTunnelService{
+ Listen: "tcp://0.0.0.0:3621",
+ Roles: []string{"role1", "role2"},
+ Service: "service",
+ Database: "database",
+ Username: "username",
+ }
+ },
+ wantErr: "",
+ },
+ {
+ name: "missing listen",
+ in: func() *DatabaseTunnelService {
+ return &DatabaseTunnelService{
+ Roles: []string{"role1", "role2"},
+ Service: "service",
+ Database: "database",
+ Username: "username",
+ }
+ },
+ wantErr: "listen: should not be empty",
+ },
+ {
+ name: "missing service",
+ in: func() *DatabaseTunnelService {
+ return &DatabaseTunnelService{
+ Listen: "tcp://0.0.0.0:3621",
+ Roles: []string{"role1", "role2"},
+ Database: "database",
+ Username: "username",
+ }
+ },
+ wantErr: "service: should not be empty",
+ },
+ {
+ name: "missing database",
+ in: func() *DatabaseTunnelService {
+ return &DatabaseTunnelService{
+ Listen: "tcp://0.0.0.0:3621",
+ Roles: []string{"role1", "role2"},
+ Service: "service",
+ Username: "username",
+ }
+ },
+ wantErr: "database: should not be empty",
+ },
+ {
+ name: "missing username",
+ in: func() *DatabaseTunnelService {
+ return &DatabaseTunnelService{
+ Listen: "tcp://0.0.0.0:3621",
+ Roles: []string{"role1", "role2"},
+ Service: "service",
+ Database: "database",
+ }
+ },
+ wantErr: "username: should not be empty",
+ },
+ }
+ testCheckAndSetDefaults(t, tests)
+}
diff --git a/lib/tbot/config/service_example.go b/lib/tbot/config/service_example.go
index f6d1172f59d9d..471b53bf5c1cd 100644
--- a/lib/tbot/config/service_example.go
+++ b/lib/tbot/config/service_example.go
@@ -19,10 +19,6 @@
package config
import (
- "context"
- "fmt"
- "time"
-
"github.com/gravitational/trace"
"gopkg.in/yaml.v3"
)
@@ -36,17 +32,6 @@ type ExampleService struct {
Message string `yaml:"message"`
}
-func (s *ExampleService) Run(ctx context.Context) error {
- for {
- select {
- case <-ctx.Done():
- return nil
- case <-time.After(time.Second * 5):
- fmt.Println("Example Service prints message:", s.Message)
- }
- }
-}
-
func (s *ExampleService) Type() string {
return ExampleServiceType
}
@@ -65,6 +50,9 @@ func (s *ExampleService) UnmarshalYAML(node *yaml.Node) error {
return nil
}
-func (s *ExampleService) String() string {
- return fmt.Sprintf("%s:%s", ExampleServiceType, s.Message)
+func (s *ExampleService) CheckAndSetDefaults() error {
+ if s.Message == "" {
+ return trace.BadParameter("message: should not be empty")
+ }
+ return nil
}
diff --git a/lib/tbot/config/testdata/TestDatabaseTunnelService_YAML/full.golden b/lib/tbot/config/testdata/TestDatabaseTunnelService_YAML/full.golden
new file mode 100644
index 0000000000000..f56e8dfa7c3c5
--- /dev/null
+++ b/lib/tbot/config/testdata/TestDatabaseTunnelService_YAML/full.golden
@@ -0,0 +1,8 @@
+type: database-tunnel
+listen: tcp://0.0.0.0:3621
+roles:
+ - role1
+ - role2
+service: service
+database: database
+username: username
diff --git a/lib/tbot/database_access.go b/lib/tbot/database_access.go
new file mode 100644
index 0000000000000..cce260d6a4c88
--- /dev/null
+++ b/lib/tbot/database_access.go
@@ -0,0 +1,96 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package tbot
+
+import (
+ "context"
+
+ "github.com/gravitational/trace"
+ "github.com/sirupsen/logrus"
+
+ apiclient "github.com/gravitational/teleport/api/client"
+ "github.com/gravitational/teleport/api/client/proto"
+ "github.com/gravitational/teleport/api/defaults"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/auth"
+ libdefaults "github.com/gravitational/teleport/lib/defaults"
+)
+
+func getDatabase(ctx context.Context, clt *auth.Client, name string) (types.Database, error) {
+ ctx, span := tracer.Start(ctx, "getDatabase")
+ defer span.End()
+
+ servers, err := apiclient.GetAllResources[types.DatabaseServer](ctx, clt, &proto.ListResourcesRequest{
+ Namespace: defaults.Namespace,
+ ResourceType: types.KindDatabaseServer,
+ PredicateExpression: makeNameOrDiscoveredNamePredicate(name),
+ Limit: int32(defaults.DefaultChunkSize),
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ var databases []types.Database
+ for _, server := range servers {
+ databases = append(databases, server.GetDatabase())
+ }
+
+ databases = types.DeduplicateDatabases(databases)
+ db, err := chooseOneDatabase(databases, name)
+ return db, trace.Wrap(err)
+}
+
+func getRouteToDatabase(
+ ctx context.Context,
+ log logrus.FieldLogger,
+ client *auth.Client,
+ service string,
+ username string,
+ database string,
+) (proto.RouteToDatabase, error) {
+ ctx, span := tracer.Start(ctx, "getRouteToDatabase")
+ defer span.End()
+
+ if service == "" {
+ return proto.RouteToDatabase{}, nil
+ }
+
+ db, err := getDatabase(ctx, client, service)
+ if err != nil {
+ return proto.RouteToDatabase{}, trace.Wrap(err)
+ }
+ // make sure the output matches the fully resolved db name, since it may
+ // have been just a "discovered name".
+ service = db.GetName()
+ if db.GetProtocol() == libdefaults.ProtocolMongoDB && username == "" {
+ // This isn't strictly a runtime error so killing the process seems
+ // wrong. We'll just loudly warn about it.
+ log.Errorf("Database `username` field for %q is unset but is required for MongoDB databases.", service)
+ } else if db.GetProtocol() == libdefaults.ProtocolRedis && username == "" {
+ // Per tsh's lead, fall back to the default username.
+ username = libdefaults.DefaultRedisUsername
+ }
+
+ return proto.RouteToDatabase{
+ ServiceName: service,
+ Protocol: db.GetProtocol(),
+ Database: database,
+ Username: username,
+ }, nil
+}
diff --git a/lib/tbot/service_database_tunnel.go b/lib/tbot/service_database_tunnel.go
new file mode 100644
index 0000000000000..7cd1c7ce67929
--- /dev/null
+++ b/lib/tbot/service_database_tunnel.go
@@ -0,0 +1,297 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package tbot
+
+import (
+ "context"
+ "crypto/tls"
+ "fmt"
+ "net"
+ "net/url"
+
+ "github.com/gravitational/trace"
+ "github.com/sirupsen/logrus"
+
+ "github.com/gravitational/teleport/api/client"
+ "github.com/gravitational/teleport/api/client/proto"
+ "github.com/gravitational/teleport/lib/auth"
+ "github.com/gravitational/teleport/lib/reversetunnelclient"
+ "github.com/gravitational/teleport/lib/srv/alpnproxy"
+ "github.com/gravitational/teleport/lib/srv/alpnproxy/common"
+ "github.com/gravitational/teleport/lib/tbot/config"
+ "github.com/gravitational/teleport/lib/tbot/identity"
+ "github.com/gravitational/teleport/lib/tlsca"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+var _ alpnproxy.LocalProxyMiddleware = (*alpnProxyMiddleware)(nil)
+
+type alpnProxyMiddleware struct {
+ onNewConnection func(ctx context.Context, lp *alpnproxy.LocalProxy, conn net.Conn) error
+ onStart func(ctx context.Context, lp *alpnproxy.LocalProxy) error
+}
+
+func (a alpnProxyMiddleware) OnNewConnection(ctx context.Context, lp *alpnproxy.LocalProxy, conn net.Conn) error {
+ if a.onNewConnection != nil {
+ return a.onNewConnection(ctx, lp, conn)
+ }
+ return nil
+}
+
+func (a alpnProxyMiddleware) OnStart(ctx context.Context, lp *alpnproxy.LocalProxy) error {
+ if a.onStart != nil {
+ return a.onStart(ctx, lp)
+ }
+ return nil
+}
+
+// DatabaseTunnelService is a service that listens on a local port and forwards
+// connections to a remote database service. It is an authenticating tunnel and
+// will automatically issue and renew certificates as needed.
+type DatabaseTunnelService struct {
+ botCfg *config.BotConfig
+ cfg *config.DatabaseTunnelService
+ proxyPingCache *proxyPingCache
+ log logrus.FieldLogger
+ resolver reversetunnelclient.Resolver
+ botClient *auth.Client
+ getBotIdentity getBotIdentityFn
+}
+
+// buildLocalProxyConfig initializes the service, fetching any initial information and setting
+// up the localproxy.
+func (s *DatabaseTunnelService) buildLocalProxyConfig(ctx context.Context) (lpCfg alpnproxy.LocalProxyConfig, err error) {
+ ctx, span := tracer.Start(ctx, "DatabaseTunnelService/buildLocalProxyConfig")
+ defer span.End()
+
+ // Determine the roles to use for the impersonated db access user. We fall
+ // back to all the roles the bot has if none are configured.
+ roles := s.cfg.Roles
+ if len(roles) == 0 {
+ roles, err = fetchDefaultRoles(ctx, s.botClient, s.getBotIdentity())
+ if err != nil {
+ return alpnproxy.LocalProxyConfig{}, trace.Wrap(err, "fetching default roles")
+ }
+ s.log.WithField("roles", roles).Debug("No roles configured, using all roles available.")
+ }
+
+ proxyPing, err := s.proxyPingCache.ping(ctx)
+ if err != nil {
+ return alpnproxy.LocalProxyConfig{}, trace.Wrap(err, "pinging proxy")
+ }
+ proxyAddr := proxyPing.Proxy.SSH.PublicAddr
+
+ // Fetch information about the database and then issue the initial
+ // certificate. We issue the initial certificate to allow us to fail faster.
+ // We cache the routeToDatabase as these will not change during the lifetime
+ // of the service and this reduces the time needed to issue a new
+ // certificate.
+ s.log.Debug("Determining route to database.")
+ routeToDatabase, err := s.getRouteToDatabaseWithImpersonation(ctx, roles)
+ if err != nil {
+ return alpnproxy.LocalProxyConfig{}, trace.Wrap(err)
+ }
+ s.log.WithFields(logrus.Fields{
+ "serviceName": routeToDatabase.ServiceName,
+ "protocol": routeToDatabase.Protocol,
+ "database": routeToDatabase.Database,
+ "username": routeToDatabase.Username,
+ }).Debug("Identified route to database.")
+
+ s.log.Debug("Issuing initial certificate for local proxy.")
+ dbCert, err := s.issueCert(ctx, routeToDatabase, roles)
+ if err != nil {
+ return alpnproxy.LocalProxyConfig{}, trace.Wrap(err)
+ }
+ s.log.Debug("Issued initial certificate for local proxy.")
+
+ middleware := alpnProxyMiddleware{
+ onNewConnection: func(ctx context.Context, lp *alpnproxy.LocalProxy, conn net.Conn) error {
+ ctx, span := tracer.Start(ctx, "DatabaseTunnelService/OnNewConnection")
+ defer span.End()
+
+ // Check if the certificate needs reissuing, if so, reissue.
+ if err := lp.CheckDBCerts(tlsca.RouteToDatabase{
+ ServiceName: routeToDatabase.ServiceName,
+ Protocol: routeToDatabase.Protocol,
+ Database: routeToDatabase.Database,
+ Username: routeToDatabase.Username,
+ }); err != nil {
+ s.log.WithField("reason", err.Error()).Info("Certificate for tunnel needs reissuing.")
+ cert, err := s.issueCert(ctx, routeToDatabase, roles)
+ if err != nil {
+ return trace.Wrap(err, "issuing cert")
+ }
+ lp.SetCerts([]tls.Certificate{*cert})
+ }
+ return nil
+ },
+ }
+
+ alpnProtocol, err := common.ToALPNProtocol(routeToDatabase.Protocol)
+ if err != nil {
+ return alpnproxy.LocalProxyConfig{}, trace.Wrap(err)
+
+ }
+ lpConfig := alpnproxy.LocalProxyConfig{
+ Middleware: middleware,
+
+ RemoteProxyAddr: proxyAddr,
+ ParentContext: ctx,
+ Protocols: []common.Protocol{alpnProtocol},
+ Certs: []tls.Certificate{*dbCert},
+ InsecureSkipVerify: s.botCfg.Insecure,
+ }
+ if client.IsALPNConnUpgradeRequired(
+ ctx,
+ proxyAddr,
+ s.botCfg.Insecure,
+ ) {
+ lpConfig.ALPNConnUpgradeRequired = true
+ // If ALPN Conn Upgrade will be used, we need to set the cluster CAs
+ // to validate the Proxy's auth issued host cert.
+ lpConfig.RootCAs = s.getBotIdentity().TLSCAPool
+ }
+
+ return lpConfig, nil
+}
+
+func (s *DatabaseTunnelService) Run(ctx context.Context) error {
+ ctx, span := tracer.Start(ctx, "DatabaseTunnelService/Run")
+ defer span.End()
+
+ listenUrl, err := url.Parse(s.cfg.Listen)
+ if err != nil {
+ return trace.Wrap(err, "parsing listen url")
+ }
+
+ s.log.WithField("address", listenUrl.String()).Debug("Opening listener for database tunnel.")
+ l, err := net.Listen("tcp", listenUrl.Host)
+ if err != nil {
+ return trace.Wrap(err, "opening listener")
+ }
+ defer func() {
+ if err := l.Close(); err != nil && !utils.IsUseOfClosedNetworkError(err) {
+ s.log.WithError(err).Error("Failed to close listener")
+ }
+ }()
+
+ lpCfg, err := s.buildLocalProxyConfig(ctx)
+ if err != nil {
+ return trace.Wrap(err, "building local proxy config")
+ }
+ lpCfg.Listener = l
+
+ lp, err := alpnproxy.NewLocalProxy(lpCfg)
+ if err != nil {
+ return trace.Wrap(err, "creating local proxy")
+ }
+ defer func() {
+ if err := lp.Close(); err != nil {
+ s.log.WithError(err).Error("Failed to close local proxy")
+ }
+ }()
+ // Closed further down.
+
+ // lp.Start will block and continues to block until lp.Close() is called.
+ // Despite taking a context, it will not exit until the first connection is
+ // made after the context is canceled.
+ var errCh = make(chan error, 1)
+ go func() {
+ errCh <- lp.Start(ctx)
+ }()
+ s.log.WithField("address", l.Addr().String()).Info("Listening for connections.")
+
+ select {
+ case <-ctx.Done():
+ return nil
+ case err := <-errCh:
+ return trace.Wrap(err, "local proxy failed")
+ }
+}
+
+// getRouteToDatabaseWithImpersonation fetches the route to the database with
+// impersonation of roles. This ensures that the user's selected roles actually
+// grant access to the database.
+func (s *DatabaseTunnelService) getRouteToDatabaseWithImpersonation(ctx context.Context, roles []string) (proto.RouteToDatabase, error) {
+ ctx, span := tracer.Start(ctx, "DatabaseTunnelService/getRouteToDatabaseWithImpersonation")
+ defer span.End()
+
+ impersonatedIdentity, err := generateIdentity(
+ ctx,
+ s.botClient,
+ s.getBotIdentity(),
+ roles,
+ s.botCfg.CertificateTTL,
+ nil,
+ )
+ if err != nil {
+ return proto.RouteToDatabase{}, trace.Wrap(err)
+ }
+
+ impersonatedClient, err := clientForFacade(
+ ctx,
+ s.log,
+ s.botCfg,
+ identity.NewFacade(s.botCfg.FIPS, s.botCfg.Insecure, impersonatedIdentity),
+ s.resolver,
+ )
+ if err != nil {
+ return proto.RouteToDatabase{}, trace.Wrap(err)
+ }
+ defer func() {
+ if err := impersonatedClient.Close(); err != nil {
+ s.log.WithError(err).Error("Failed to close impersonated client.")
+ }
+ }()
+
+ return getRouteToDatabase(ctx, s.log, impersonatedClient, s.cfg.Service, s.cfg.Username, s.cfg.Database)
+}
+
+func (s *DatabaseTunnelService) issueCert(
+ ctx context.Context,
+ route proto.RouteToDatabase,
+ roles []string,
+) (*tls.Certificate, error) {
+ ctx, span := tracer.Start(ctx, "DatabaseTunnelService/issueCert")
+ defer span.End()
+
+ s.log.Debug("Requesting issuance of certificate for tunnel proxy.")
+ ident, err := generateIdentity(
+ ctx,
+ s.botClient,
+ s.getBotIdentity(),
+ roles,
+ s.botCfg.CertificateTTL,
+ func(req *proto.UserCertsRequest) {
+ req.RouteToDatabase = route
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ s.log.Info("Certificate issued for tunnel proxy.")
+
+ return ident.TLSCert, nil
+}
+
+// String returns a human-readable string that can uniquely identify the
+// service.
+func (s *DatabaseTunnelService) String() string {
+ return fmt.Sprintf("%s:%s:%s", config.DatabaseTunnelServiceType, s.cfg.Listen, s.cfg.Service)
+}
diff --git a/lib/tbot/service_example.go b/lib/tbot/service_example.go
new file mode 100644
index 0000000000000..c550c52d9f24f
--- /dev/null
+++ b/lib/tbot/service_example.go
@@ -0,0 +1,50 @@
+/*
+ * Teleport
+ * Copyright (C) 2024 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see .
+ */
+
+package tbot
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "github.com/gravitational/teleport/lib/tbot/config"
+)
+
+// ExampleService is a temporary example service for testing purposes. It is
+// not intended to be used and exists to demonstrate how a user configurable
+// service integrates with the tbot service manager.
+type ExampleService struct {
+ cfg *config.ExampleService
+ Message string `yaml:"message"`
+}
+
+func (s *ExampleService) Run(ctx context.Context) error {
+ for {
+ select {
+ case <-ctx.Done():
+ return nil
+ case <-time.After(time.Second * 5):
+ fmt.Println("Example Service prints message:", s.Message)
+ }
+ }
+}
+
+func (s *ExampleService) String() string {
+ return fmt.Sprintf("%s:%s", config.ExampleServiceType, s.Message)
+}
diff --git a/lib/tbot/service_outputs.go b/lib/tbot/service_outputs.go
index ec68b3a7ddc33..d7e7775610869 100644
--- a/lib/tbot/service_outputs.go
+++ b/lib/tbot/service_outputs.go
@@ -38,7 +38,6 @@ import (
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/native"
- libdefaults "github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/reversetunnelclient"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/tbot/config"
@@ -56,6 +55,8 @@ const renewalRetryLimit = 5
type outputsService struct {
log logrus.FieldLogger
reloadBroadcaster *channelBroadcaster
+ proxyPingCache *proxyPingCache
+ authPingCache *authPingCache
botClient *auth.Client
getBotIdentity getBotIdentityFn
cfg *config.BotConfig
@@ -86,16 +87,16 @@ func (s *outputsService) renewOutputs(
// create a cache shared across outputs so they don't hammer the auth
// server with similar requests
drc := &outputRenewalCache{
- client: s.botClient,
- cfg: s.cfg,
+ proxyPingCache: s.proxyPingCache,
+ authPingCache: s.authPingCache,
+ client: s.botClient,
+ cfg: s.cfg,
}
// Determine the default role list based on the bot role. The role's
// name should match the certificate's Key ID (user and role names
// should all match bot-$name)
- botIdentity := s.getBotIdentity()
- botResourceName := botIdentity.X509Cert.Subject.CommonName
- defaultRoles, err := fetchDefaultRoles(ctx, s.botClient, botResourceName)
+ defaultRoles, err := fetchDefaultRoles(ctx, s.botClient, s.getBotIdentity())
if err != nil {
s.log.WithError(err).Warnf("Unable to determine default roles, no roles will be requested if unspecified")
defaultRoles = []string{}
@@ -124,7 +125,7 @@ func (s *outputsService) renewOutputs(
}
impersonatedIdentity, impersonatedClient, err := s.generateImpersonatedIdentity(
- ctx, s.botClient, botIdentity, output, defaultRoles,
+ ctx, s.botClient, s.getBotIdentity(), output, defaultRoles,
)
if err != nil {
return trace.Wrap(err, "generating impersonated certs for output: %s", output)
@@ -286,15 +287,15 @@ type identityConfigurator = func(req *proto.UserCertsRequest)
// impersonated identity that already has the relevant permissions, much like
// `tsh (app|db|kube) login` is already used to generate an additional set of
// certs.
-func (s *outputsService) generateIdentity(
+func generateIdentity(
ctx context.Context,
client *auth.Client,
currentIdentity *identity.Identity,
- output config.Output,
- defaultRoles []string,
+ roles []string,
+ ttl time.Duration,
configurator identityConfigurator,
) (*identity.Identity, error) {
- ctx, span := tracer.Start(ctx, "outputsService/generateIdentity")
+ ctx, span := tracer.Start(ctx, "generateIdentity")
defer span.End()
// TODO: enforce expiration > renewal period (by what margin?)
@@ -309,19 +310,11 @@ func (s *outputsService) generateIdentity(
return nil, trace.Wrap(err)
}
- var roleRequests []string
- if roles := output.GetRoles(); len(roles) > 0 {
- roleRequests = roles
- } else {
- s.log.Debugf("Output specified no roles, defaults will be requested: %v", defaultRoles)
- roleRequests = defaultRoles
- }
-
req := proto.UserCertsRequest{
PublicKey: publicKey,
Username: currentIdentity.X509Cert.Subject.CommonName,
- Expires: time.Now().Add(s.cfg.CertificateTTL),
- RoleRequests: roleRequests,
+ Expires: time.Now().Add(ttl),
+ RoleRequests: roles,
RouteToCluster: currentIdentity.ClusterName,
// Make sure to specify this is an impersonated cert request. If unset,
@@ -380,64 +373,6 @@ func (s *outputsService) generateIdentity(
return newIdentity, nil
}
-func getDatabase(ctx context.Context, clt *auth.Client, name string) (types.Database, error) {
- ctx, span := tracer.Start(ctx, "getDatabase")
- defer span.End()
-
- servers, err := apiclient.GetAllResources[types.DatabaseServer](ctx, clt, &proto.ListResourcesRequest{
- Namespace: defaults.Namespace,
- ResourceType: types.KindDatabaseServer,
- PredicateExpression: makeNameOrDiscoveredNamePredicate(name),
- Limit: int32(defaults.DefaultChunkSize),
- })
- if err != nil {
- return nil, trace.Wrap(err)
- }
-
- var databases []types.Database
- for _, server := range servers {
- databases = append(databases, server.GetDatabase())
- }
-
- databases = types.DeduplicateDatabases(databases)
- db, err := chooseOneDatabase(databases, name)
- return db, trace.Wrap(err)
-}
-
-func (s *outputsService) getRouteToDatabase(ctx context.Context, client *auth.Client, output *config.DatabaseOutput) (proto.RouteToDatabase, error) {
- ctx, span := tracer.Start(ctx, "outputsService/getRouteToDatabase")
- defer span.End()
-
- if output.Service == "" {
- return proto.RouteToDatabase{}, nil
- }
-
- db, err := getDatabase(ctx, client, output.Service)
- if err != nil {
- return proto.RouteToDatabase{}, trace.Wrap(err)
- }
- // make sure the output matches the fully resolved db name, since it may
- // have been just a "discovered name".
- output.Service = db.GetName()
-
- username := output.Username
- if db.GetProtocol() == libdefaults.ProtocolMongoDB && username == "" {
- // This isn't strictly a runtime error so killing the process seems
- // wrong. We'll just loudly warn about it.
- s.log.Errorf("Database `username` field for %q is unset but is required for MongoDB databases.", output.Service)
- } else if db.GetProtocol() == libdefaults.ProtocolRedis && username == "" {
- // Per tsh's lead, fall back to the default username.
- username = libdefaults.DefaultRedisUsername
- }
-
- return proto.RouteToDatabase{
- ServiceName: output.Service,
- Protocol: db.GetProtocol(),
- Database: output.Database,
- Username: username,
- }, nil
-}
-
func getKubeCluster(ctx context.Context, clt *auth.Client, name string) (types.KubeCluster, error) {
ctx, span := tracer.Start(ctx, "getKubeCluster")
defer span.End()
@@ -534,8 +469,14 @@ func (s *outputsService) generateImpersonatedIdentity(
ctx, span := tracer.Start(ctx, "outputsService/generateImpersonatedIdentity")
defer span.End()
- impersonatedIdentity, err = s.generateIdentity(
- ctx, botClient, botIdentity, output, defaultRoles, nil,
+ roles := output.GetRoles()
+ if len(roles) == 0 {
+ s.log.Debugf("Output specified no roles, defaults will be requested: %v", defaultRoles)
+ roles = defaultRoles
+ }
+
+ impersonatedIdentity, err = generateIdentity(
+ ctx, botClient, botIdentity, roles, s.cfg.CertificateTTL, nil,
)
if err != nil {
return nil, nil, trace.Wrap(err)
@@ -564,16 +505,29 @@ func (s *outputsService) generateImpersonatedIdentity(
return impersonatedIdentity, impersonatedClient, nil
}
- routedIdentity, err := s.generateIdentity(ctx, botClient, impersonatedIdentity, output, defaultRoles, func(req *proto.UserCertsRequest) {
- req.RouteToCluster = output.Cluster
- },
+ routedIdentity, err := generateIdentity(
+ ctx,
+ botClient,
+ impersonatedIdentity,
+ roles,
+ s.cfg.CertificateTTL,
+ func(req *proto.UserCertsRequest) {
+ req.RouteToCluster = output.Cluster
+ },
)
if err != nil {
return nil, nil, trace.Wrap(err)
}
return routedIdentity, impersonatedClient, nil
case *config.DatabaseOutput:
- route, err := s.getRouteToDatabase(ctx, impersonatedClient, output)
+ route, err := getRouteToDatabase(
+ ctx,
+ s.log,
+ impersonatedClient,
+ output.Service,
+ output.Username,
+ output.Database,
+ )
if err != nil {
return nil, nil, trace.Wrap(err)
}
@@ -582,9 +536,16 @@ func (s *outputsService) generateImpersonatedIdentity(
// so we'll request the database access identity using the main bot
// identity (having gathered the necessary info for RouteToDatabase
// using the correct impersonated unroutedIdentity.)
- routedIdentity, err := s.generateIdentity(ctx, botClient, impersonatedIdentity, output, defaultRoles, func(req *proto.UserCertsRequest) {
- req.RouteToDatabase = route
- })
+ routedIdentity, err := generateIdentity(
+ ctx,
+ botClient,
+ impersonatedIdentity,
+ roles,
+ s.cfg.CertificateTTL,
+ func(req *proto.UserCertsRequest) {
+ req.RouteToDatabase = route
+ },
+ )
if err != nil {
return nil, nil, trace.Wrap(err)
}
@@ -603,9 +564,16 @@ func (s *outputsService) generateImpersonatedIdentity(
// Note: the Teleport server does attempt to verify k8s cluster names
// and will fail to generate certs if the cluster doesn't exist or is
// offline.
- routedIdentity, err := s.generateIdentity(ctx, botClient, impersonatedIdentity, output, defaultRoles, func(req *proto.UserCertsRequest) {
- req.KubernetesCluster = output.KubernetesCluster
- })
+ routedIdentity, err := generateIdentity(
+ ctx,
+ botClient,
+ impersonatedIdentity,
+ roles,
+ s.cfg.CertificateTTL,
+ func(req *proto.UserCertsRequest) {
+ req.KubernetesCluster = output.KubernetesCluster
+ },
+ )
if err != nil {
return nil, nil, trace.Wrap(err)
}
@@ -619,9 +587,16 @@ func (s *outputsService) generateImpersonatedIdentity(
return nil, nil, trace.Wrap(err)
}
- routedIdentity, err := s.generateIdentity(ctx, botClient, impersonatedIdentity, output, defaultRoles, func(req *proto.UserCertsRequest) {
- req.RouteToApp = routeToApp
- })
+ routedIdentity, err := generateIdentity(
+ ctx,
+ botClient,
+ impersonatedIdentity,
+ roles,
+ s.cfg.CertificateTTL,
+ func(req *proto.UserCertsRequest) {
+ req.RouteToApp = routeToApp
+ },
+ )
if err != nil {
return nil, nil, trace.Wrap(err)
}
@@ -640,8 +615,8 @@ func (s *outputsService) generateImpersonatedIdentity(
// fetchDefaultRoles requests the bot's own role from the auth server and
// extracts its full list of allowed roles.
-func fetchDefaultRoles(ctx context.Context, roleGetter services.RoleGetter, botRole string) ([]string, error) {
- role, err := roleGetter.GetRole(ctx, botRole)
+func fetchDefaultRoles(ctx context.Context, roleGetter services.RoleGetter, identity *identity.Identity) ([]string, error) {
+ role, err := roleGetter.GetRole(ctx, identity.X509Cert.Subject.CommonName)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -655,14 +630,14 @@ func fetchDefaultRoles(ctx context.Context, roleGetter services.RoleGetter, botR
// requests for the same information. This is shared between all of the
// outputs.
type outputRenewalCache struct {
- client *auth.Client
+ client *auth.Client
+ cfg *config.BotConfig
+ proxyPingCache *proxyPingCache
+ authPingCache *authPingCache
- cfg *config.BotConfig
- mu sync.Mutex
+ mu sync.Mutex
// These are protected by getter/setters with mutex locks
- _cas map[types.CertAuthType][]types.CertAuthority
- _authPong *proto.PingResponse
- _proxyPong *webclient.PingResponse
+ _cas map[types.CertAuthType][]types.CertAuthority
}
func (orc *outputRenewalCache) getCertAuthorities(
@@ -694,71 +669,23 @@ func (orc *outputRenewalCache) GetCertAuthorities(
return orc.getCertAuthorities(ctx, caType)
}
-func (orc *outputRenewalCache) authPing(ctx context.Context) (*proto.PingResponse, error) {
- if orc._authPong != nil {
- return orc._authPong, nil
- }
-
- pong, err := orc.client.Ping(ctx)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- orc._authPong = &pong
-
- return &pong, nil
-}
-
// AuthPing pings the auth server and returns the (possibly cached) response.
func (orc *outputRenewalCache) AuthPing(ctx context.Context) (*proto.PingResponse, error) {
- orc.mu.Lock()
- defer orc.mu.Unlock()
- return orc.authPing(ctx)
-}
-
-func (orc *outputRenewalCache) proxyPing(ctx context.Context) (*webclient.PingResponse, error) {
- if orc._proxyPong != nil {
- return orc._proxyPong, nil
- }
-
- // Determine the Proxy address to use.
- addr, addrKind := orc.cfg.Address()
- switch addrKind {
- case config.AddressKindAuth:
- // If the address is an auth address, ping auth to determine proxy addr.
- authPong, err := orc.authPing(ctx)
- if err != nil {
- return nil, trace.Wrap(err)
- }
- addr = authPong.ProxyPublicAddr
- case config.AddressKindProxy:
- // If the address is a proxy address, use it directly.
- default:
- return nil, trace.BadParameter("unsupported address kind: %v", addrKind)
- }
-
- // We use find instead of Ping as it's less resource intense and we can
- // ping the AuthServer directly for its configuration if necessary.
- proxyPong, err := webclient.Find(&webclient.Config{
- Context: ctx,
- ProxyAddr: addr,
- Insecure: orc.cfg.Insecure,
- })
+ res, err := orc.authPingCache.ping(ctx)
if err != nil {
return nil, trace.Wrap(err)
- }
-
- orc._proxyPong = proxyPong
- return proxyPong, nil
+ }
+ return &res, nil
}
// ProxyPing returns a (possibly cached) ping response from the Teleport proxy.
-// Note that it relies on the auth server being configured with a sane proxy
-// public address.
func (orc *outputRenewalCache) ProxyPing(ctx context.Context) (*webclient.PingResponse, error) {
- orc.mu.Lock()
- defer orc.mu.Unlock()
- return orc.proxyPing(ctx)
+ res, err := orc.proxyPingCache.ping(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return res, nil
}
// Config returns the bots config.
diff --git a/lib/tbot/tbot.go b/lib/tbot/tbot.go
index a893cdee0e695..b0a0355aebec1 100644
--- a/lib/tbot/tbot.go
+++ b/lib/tbot/tbot.go
@@ -29,6 +29,7 @@ import (
"google.golang.org/grpc"
"github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/client/webclient"
"github.com/gravitational/teleport/api/metadata"
"github.com/gravitational/teleport/api/types"
@@ -36,7 +37,6 @@ import (
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/reversetunnelclient"
- "github.com/gravitational/teleport/lib/tbot/bot"
"github.com/gravitational/teleport/lib/tbot/config"
"github.com/gravitational/teleport/lib/tbot/identity"
"github.com/gravitational/teleport/lib/utils"
@@ -46,6 +46,29 @@ var tracer = otel.Tracer("github.com/gravitational/teleport/lib/tbot")
const componentTBot = "tbot"
+// Service is a long-running sub-component of tbot.
+type Service interface {
+ // String returns a human-readable name for the service that can be used
+ // in logging. It should identify the type of the service and any top
+ // level configuration that could distinguish it from a same-type service.
+ String() string
+ // Run starts the service and blocks until the service exits. It should
+ // return a nil error if the service exits successfully and an error
+ // if it is unable to proceed. It should exit gracefully if the context
+ // is canceled.
+ Run(ctx context.Context) error
+}
+
+// OneShotService is a [Service] that offers a mode in which it runs a single
+// time and then exits. This aligns with the `--oneshot` mode of tbot.
+type OneShotService interface {
+ Service
+ // OneShot runs the service once and then exits. It should return a nil
+ // error if the service exits successfully and an error if it is unable
+ // to proceed. It should exit gracefully if the context is canceled.
+ OneShot(ctx context.Context) error
+}
+
type Bot struct {
cfg *config.BotConfig
log logrus.FieldLogger
@@ -123,7 +146,7 @@ func (b *Bot) Run(ctx context.Context) error {
// Create an error group to manage all the services lifetimes.
eg, egCtx := errgroup.WithContext(ctx)
- var services []bot.Service
+ var services []Service
// ReloadBroadcaster allows multiple entities to trigger a reload of
// all services. This allows os signals and other events such as CA
@@ -167,6 +190,16 @@ func (b *Bot) Run(ctx context.Context) error {
}()
services = append(services, b.botIdentitySvc)
+ authPingCache := &authPingCache{
+ client: b.botIdentitySvc.GetClient(),
+ log: b.log,
+ }
+ proxyPingCache := &proxyPingCache{
+ authPingCache: authPingCache,
+ botCfg: b.cfg,
+ log: b.log,
+ }
+
// Setup all other services
if b.cfg.DiagAddr != "" {
services = append(services, &diagnosticsService{
@@ -178,6 +211,8 @@ func (b *Bot) Run(ctx context.Context) error {
})
}
services = append(services, &outputsService{
+ authPingCache: authPingCache,
+ proxyPingCache: proxyPingCache,
getBotIdentity: b.botIdentitySvc.GetIdentity,
botClient: b.botIdentitySvc.GetClient(),
cfg: b.cfg,
@@ -196,7 +231,30 @@ func (b *Bot) Run(ctx context.Context) error {
reloadBroadcaster: reloadBroadcaster,
})
// Append any services configured by the user
- services = append(services, b.cfg.Services...)
+ for _, svcCfg := range b.cfg.Services {
+ // Convert the service config into the actual service type.
+ switch svcCfg := svcCfg.(type) {
+ case *config.DatabaseTunnelService:
+ svc := &DatabaseTunnelService{
+ getBotIdentity: b.botIdentitySvc.GetIdentity,
+ proxyPingCache: proxyPingCache,
+ botClient: b.botIdentitySvc.GetClient(),
+ resolver: resolver,
+ botCfg: b.cfg,
+ cfg: svcCfg,
+ }
+ svc.log = b.log.WithField(
+ trace.Component, teleport.Component(componentTBot, "svc", svc.String()),
+ )
+ services = append(services, svc)
+ case *config.ExampleService:
+ services = append(services, &ExampleService{
+ cfg: svcCfg,
+ })
+ default:
+ return trace.BadParameter("unknown service type: %T", svcCfg)
+ }
+ }
b.log.Info("Initialization complete. Starting services.")
// Start services
@@ -205,7 +263,7 @@ func (b *Bot) Run(ctx context.Context) error {
log := b.log.WithField("service", svc.String())
if b.cfg.Oneshot {
- svc, ok := svc.(bot.OneShotService)
+ svc, ok := svc.(OneShotService)
// We ignore services with no one-shot implementation
if !ok {
log.Debug("Service does not support oneshot mode, ignoring.")
@@ -399,3 +457,78 @@ func clientForFacade(
c, err := authclient.Connect(ctx, authClientConfig)
return c, trace.Wrap(err)
}
+
+type authPingCache struct {
+ client *auth.Client
+ log logrus.FieldLogger
+
+ mu sync.RWMutex
+ cachedValue *proto.PingResponse
+}
+
+func (a *authPingCache) ping(ctx context.Context) (proto.PingResponse, error) {
+ a.mu.Lock()
+ defer a.mu.Unlock()
+ if a.cachedValue != nil {
+ return *a.cachedValue, nil
+ }
+
+ a.log.Debug("Pinging auth server.")
+ res, err := a.client.Ping(ctx)
+ if err != nil {
+ a.log.WithError(err).Error("Failed to ping auth server.")
+ return proto.PingResponse{}, trace.Wrap(err)
+ }
+ a.cachedValue = &res
+ a.log.WithField("pong", res).Debug("Successfully pinged auth server.")
+
+ return *a.cachedValue, nil
+}
+
+type proxyPingCache struct {
+ authPingCache *authPingCache
+ botCfg *config.BotConfig
+ log logrus.FieldLogger
+
+ mu sync.RWMutex
+ cachedValue *webclient.PingResponse
+}
+
+func (p *proxyPingCache) ping(ctx context.Context) (*webclient.PingResponse, error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.cachedValue != nil {
+ return p.cachedValue, nil
+ }
+
+ // Determine the Proxy address to use.
+ addr, addrKind := p.botCfg.Address()
+ switch addrKind {
+ case config.AddressKindAuth:
+ // If the address is an auth address, ping auth to determine proxy addr.
+ authPong, err := p.authPingCache.ping(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ addr = authPong.ProxyPublicAddr
+ case config.AddressKindProxy:
+ // If the address is a proxy address, use it directly.
+ default:
+ return nil, trace.BadParameter("unsupported address kind: %v", addrKind)
+ }
+
+ p.log.WithField("addr", addr).Debug("Pinging proxy.")
+ res, err := webclient.Find(&webclient.Config{
+ Context: ctx,
+ ProxyAddr: addr,
+ Insecure: p.botCfg.Insecure,
+ })
+ if err != nil {
+ p.log.WithError(err).Error("Failed to ping proxy.")
+ return nil, trace.Wrap(err)
+ }
+ p.log.WithField("pong", res).Debug("Successfully pinged proxy.")
+ p.cachedValue = res
+
+ return p.cachedValue, nil
+}
diff --git a/lib/tbot/tbot_test.go b/lib/tbot/tbot_test.go
index 47da5e5fc5da3..34ec59a8b123d 100644
--- a/lib/tbot/tbot_test.go
+++ b/lib/tbot/tbot_test.go
@@ -21,16 +21,24 @@ import (
"context"
"crypto/rand"
"fmt"
+ "net"
"os"
"testing"
+ "time"
+ "github.com/jackc/pgconn"
+ "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/sshutils"
+ "github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/lib/auth/native"
+ "github.com/gravitational/teleport/lib/service/servicecfg"
+ "github.com/gravitational/teleport/lib/srv/db/common"
+ "github.com/gravitational/teleport/lib/srv/db/postgres"
apisshutils "github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/tbot/bot"
"github.com/gravitational/teleport/lib/tbot/botfs"
@@ -599,3 +607,98 @@ func newMockDiscoveredKubeCluster(t *testing.T, name, discoveredName string) *ty
require.NoError(t, err)
return kubeCluster
}
+
+func TestBotDatabaseTunnel(t *testing.T) {
+ t.Parallel()
+ ctx := context.Background()
+ log := utils.NewLoggerForTests()
+
+ // Make a new auth server.
+ fc, fds := testhelpers.DefaultConfig(t)
+ process := testhelpers.MakeAndRunTestAuthServer(t, log, fc, fds)
+ rootClient := testhelpers.MakeDefaultAuthClient(t, log, fc)
+
+ // Make fake postgres server and add a database access instance to expose
+ // it.
+ pts, err := postgres.NewTestServer(common.TestServerConfig{
+ AuthClient: rootClient,
+ Users: []string{"llama"},
+ })
+ require.NoError(t, err)
+ go func() {
+ t.Logf("Postgres Fake server running at %s port", pts.Port())
+ require.NoError(t, pts.Serve())
+ }()
+ t.Cleanup(func() {
+ pts.Close()
+ })
+ proxyAddr, err := process.ProxyWebAddr()
+ require.NoError(t, err)
+ helpers.MakeTestDatabaseServer(t, *proxyAddr, testhelpers.AgentJoinToken, nil, servicecfg.Database{
+ Name: "test-database",
+ URI: net.JoinHostPort("localhost", pts.Port()),
+ Protocol: "postgres",
+ })
+
+ // Create role that allows the bot to access the database.
+ role, err := types.NewRole("database-access", types.RoleSpecV6{
+ Allow: types.RoleConditions{
+ DatabaseLabels: types.Labels{
+ "*": apiutils.Strings{"*"},
+ },
+ DatabaseNames: []string{"mydb"},
+ DatabaseUsers: []string{"llama"},
+ },
+ })
+ require.NoError(t, err)
+ err = rootClient.UpsertRole(ctx, role)
+ require.NoError(t, err)
+
+ // Prepare the bot config
+ botParams := testhelpers.MakeBot(t, rootClient, "test", role.GetName())
+ botConfig := testhelpers.DefaultBotConfig(
+ t, fc, botParams, []config.Output{},
+ testhelpers.DefaultBotConfigOpts{
+ UseAuthServer: true,
+ // Insecure required as the db tunnel will connect to proxies
+ // self-signed.
+ Insecure: true,
+ ServiceConfigs: []config.ServiceConfig{
+ &config.DatabaseTunnelService{
+ // TODO: Perhaps allow FD or listener to be injected
+ Listen: "tcp://127.0.0.1:39933",
+ Service: "test-database",
+ Database: "mydb",
+ Username: "llama",
+ },
+ },
+ },
+ )
+ botConfig.Oneshot = false
+ b := New(botConfig, log)
+
+ // Spin up goroutine for bot to run in
+ botCtx, cancelBot := context.WithCancel(ctx)
+ botCh := make(chan error, 1)
+ go func() {
+ botCh <- b.Run(botCtx)
+ }()
+
+ // We can't predict exactly when the tunnel will be ready so we use
+ // EventuallyWithT to retry.
+ require.EventuallyWithT(t, func(t *assert.CollectT) {
+ conn, err := pgconn.Connect(ctx, "postgres://127.0.0.1:39933/mydb?user=llama")
+ if !assert.NoError(t, err) {
+ return
+ }
+ defer func() {
+ conn.Close(ctx)
+ }()
+ _, err = conn.Exec(ctx, "SELECT 1;").ReadAll()
+ assert.NoError(t, err)
+ }, 5*time.Second, 100*time.Millisecond)
+
+ // Shut down bot and make sure it exits cleanly.
+ cancelBot()
+ require.NoError(t, <-botCh)
+}
diff --git a/lib/tbot/testhelpers/srv.go b/lib/tbot/testhelpers/srv.go
index 7c8015f5e2a37..b1f15390bbac5 100644
--- a/lib/tbot/testhelpers/srv.go
+++ b/lib/tbot/testhelpers/srv.go
@@ -43,8 +43,12 @@ type DefaultBotConfigOpts struct {
// Makes the bot accept an Insecure auth or proxy server
Insecure bool
+
+ ServiceConfigs botconfig.ServiceConfigs
}
+const AgentJoinToken = "i-am-a-join-token"
+
// DefaultConfig returns a FileConfig to be used in tests, with random listen
// addresses that are tied to the listeners returned in the FileDescriptor
// slice, which should be passed as exported file descriptors to NewTeleport;
@@ -73,6 +77,9 @@ func DefaultConfig(t *testing.T) (*config.FileConfig, []*servicecfg.FileDescript
EnabledFlag: "true",
ListenAddress: testenv.NewTCPListener(t, service.ListenerAuth, &fds),
},
+ StaticTokens: config.StaticTokens{
+ config.StaticToken("db:" + AgentJoinToken),
+ },
},
}
@@ -176,7 +183,11 @@ func MakeBot(t *testing.T, client *auth.Client, name string, roles ...string) *p
// - Uses a memory storage destination
// - Does not verify Proxy WebAPI certificates
func DefaultBotConfig(
- t *testing.T, fc *config.FileConfig, botParams *proto.CreateBotResponse, outputs []botconfig.Output, opts DefaultBotConfigOpts,
+ t *testing.T,
+ fc *config.FileConfig,
+ botParams *proto.CreateBotResponse,
+ outputs []botconfig.Output,
+ opts DefaultBotConfigOpts,
) *botconfig.BotConfig {
t.Helper()
@@ -202,6 +213,7 @@ func DefaultBotConfig(
// Set Insecure so the bot will trust the Proxy's webapi default signed
// certs.
Insecure: opts.Insecure,
+ Services: opts.ServiceConfigs,
}
cfg.Onboarding.SetToken(botParams.TokenID)