diff --git a/api/types/resource.go b/api/types/resource.go
index bf5132cb56a99..544f7b1f92b8c 100644
--- a/api/types/resource.go
+++ b/api/types/resource.go
@@ -17,6 +17,7 @@ limitations under the License.
package types
import (
+ "iter"
"regexp"
"slices"
"sort"
@@ -30,6 +31,7 @@ import (
"github.com/gravitational/teleport/api/types/common"
"github.com/gravitational/teleport/api/types/compare"
"github.com/gravitational/teleport/api/utils"
+ "github.com/gravitational/teleport/api/utils/iterutils"
)
var (
@@ -84,6 +86,12 @@ func GetName[R Resource](r R) string {
return r.GetName()
}
+// ResourceNames creates an iterator that loops through the provided slice of
+// resources and returns their names.
+func ResourceNames[R Resource, S ~[]R](s S) iter.Seq[string] {
+ return iterutils.Map(GetName, slices.Values(s))
+}
+
// ResourceDetails includes details about the resource
type ResourceDetails struct {
Hostname string
diff --git a/api/types/resource_test.go b/api/types/resource_test.go
index 53b38ef33e145..f41da33f98c90 100644
--- a/api/types/resource_test.go
+++ b/api/types/resource_test.go
@@ -17,6 +17,8 @@
package types
import (
+ "fmt"
+ "slices"
"testing"
"time"
@@ -820,3 +822,20 @@ func TestResourceHeaderIsEqual(t *testing.T) {
})
}
}
+
+func TestResourceNames(t *testing.T) {
+ var apps Apps
+ var expectedNames []string
+ for i := 0; i < 10; i++ {
+ app, err := NewAppV3(Metadata{
+ Name: fmt.Sprintf("app-%d", i),
+ }, AppSpecV3{
+ URI: "tcp://localhost:1111",
+ })
+ require.NoError(t, err)
+ apps = append(apps, app)
+ expectedNames = append(expectedNames, app.GetName())
+ }
+
+ require.Equal(t, expectedNames, slices.Collect(ResourceNames(apps)))
+}
diff --git a/api/utils/iterutils/iter.go b/api/utils/iterutils/iter.go
new file mode 100644
index 0000000000000..3d0ddf495f43c
--- /dev/null
+++ b/api/utils/iterutils/iter.go
@@ -0,0 +1,37 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package iterutils
+
+import (
+ "iter"
+)
+
+// Map returns an iterator over f applied to seq.
+//
+// Copied from https://github.com/golang/go/issues/61898. We should switch to an
+// official package once it is available.
+func Map[In, Out any](f func(In) Out, seq iter.Seq[In]) iter.Seq[Out] {
+ return func(yield func(Out) bool) {
+ for in := range seq {
+ if !yield(f(in)) {
+ return
+ }
+ }
+ }
+}
diff --git a/api/utils/iterutils/iter_test.go b/api/utils/iterutils/iter_test.go
new file mode 100644
index 0000000000000..e6bf3e3feecd2
--- /dev/null
+++ b/api/utils/iterutils/iter_test.go
@@ -0,0 +1,39 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package iterutils
+
+import (
+ "fmt"
+ "slices"
+ "strings"
+)
+
+func ExampleMap() {
+ inputs := []string{
+ "hello world",
+ "foo",
+ }
+
+ for mapped := range Map(strings.ToUpper, slices.Values(inputs)) {
+ fmt.Println(mapped)
+ }
+ // Output:
+ // HELLO WORLD
+ // FOO
+}
diff --git a/lib/client/db/dbcmd/exec.go b/lib/client/db/dbcmd/exec.go
new file mode 100644
index 0000000000000..e31f25110ca7c
--- /dev/null
+++ b/lib/client/db/dbcmd/exec.go
@@ -0,0 +1,60 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package dbcmd
+
+import (
+ "context"
+ "os/exec"
+
+ "github.com/gravitational/trace"
+
+ "github.com/gravitational/teleport/lib/defaults"
+)
+
+// GetExecCommand returns a command that executes the provided query on the
+// target database using an appropriate CLI database client.
+func (c *CLICommandBuilder) GetExecCommand(_ context.Context, query string) (*exec.Cmd, error) {
+ if !c.options.noTLS || c.options.localProxyHost == "" {
+ return nil, trace.BadParameter("query execution is only supported when using an authenticated local proxy")
+ }
+
+ switch c.db.Protocol {
+ case defaults.ProtocolPostgres:
+ return c.getPostgresExecCommand(query)
+ case defaults.ProtocolMySQL:
+ return c.getMySQLExecCommand(query)
+ default:
+ return nil, trace.BadParameter("%s databases not supported for exec command", c.db.Protocol)
+ }
+}
+
+func (c *CLICommandBuilder) getPostgresExecCommand(query string) (*exec.Cmd, error) {
+ cmd := c.getPostgresCommand()
+ cmd.Args = append(cmd.Args, "-c", query)
+ return cmd, nil
+}
+
+func (c *CLICommandBuilder) getMySQLExecCommand(query string) (*exec.Cmd, error) {
+ cmd, err := c.getMySQLCommand()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ cmd.Args = append(cmd.Args, "-e", query)
+ return cmd, nil
+}
diff --git a/lib/client/db/dbcmd/exec_test.go b/lib/client/db/dbcmd/exec_test.go
new file mode 100644
index 0000000000000..03f8743ef7180
--- /dev/null
+++ b/lib/client/db/dbcmd/exec_test.go
@@ -0,0 +1,120 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package dbcmd
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport/lib/client"
+ "github.com/gravitational/teleport/lib/defaults"
+ "github.com/gravitational/teleport/lib/observability/tracing"
+ "github.com/gravitational/teleport/lib/tlsca"
+ "github.com/gravitational/teleport/lib/utils"
+)
+
+func TestCLICommandBuilderGetExecCommand(t *testing.T) {
+ fakeExec := &fakeExec{
+ execOutput: map[string][]byte{
+ "psql": []byte(""),
+ "mysql": []byte(""),
+ },
+ }
+
+ conf := &client.Config{
+ HomePath: t.TempDir(),
+ Host: "localhost",
+ WebProxyAddr: "proxy.example.com",
+ SiteName: "db.example.com",
+ Tracer: tracing.NoopProvider().Tracer("test"),
+ }
+
+ tc, err := client.NewClient(conf)
+ require.NoError(t, err)
+
+ profile := &client.ProfileStatus{
+ Name: "example.com",
+ Username: "bob",
+ Dir: "/tmp",
+ Cluster: "example.com",
+ }
+
+ tests := []struct {
+ name string
+ opts []ConnectCommandFunc
+ protocol string
+ cmd []string
+ wantErr bool
+ }{
+ {
+ name: "not authenticated tunnel",
+ protocol: defaults.ProtocolPostgres,
+ wantErr: true,
+ },
+ {
+ name: "unsupported protocol",
+ protocol: defaults.ProtocolDynamoDB,
+ opts: []ConnectCommandFunc{WithNoTLS()},
+ wantErr: true,
+ },
+ {
+ name: "postgres",
+ protocol: defaults.ProtocolPostgres,
+ opts: []ConnectCommandFunc{WithNoTLS()},
+ cmd: []string{"psql", "postgres://db-user@localhost:12345/db-name", "-c", "select 1"},
+ wantErr: false,
+ },
+ {
+ name: "mysql",
+ protocol: defaults.ProtocolMySQL,
+ opts: []ConnectCommandFunc{WithNoTLS()},
+ cmd: []string{"mysql", "--user", "db-user", "--database", "db-name", "--port", "12345", "--host", "localhost", "--protocol", "TCP", "-e", "select 1"},
+ wantErr: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ database := tlsca.RouteToDatabase{
+ Protocol: tt.protocol,
+ Database: "db-name",
+ Username: "db-user",
+ ServiceName: "db-service",
+ }
+
+ opts := append([]ConnectCommandFunc{
+ WithLocalProxy("localhost", 12345, ""),
+ WithExecer(fakeExec),
+ }, tt.opts...)
+
+ c := NewCmdBuilder(tc, profile, database, "root", opts...)
+ c.uid = utils.NewFakeUID()
+ got, err := c.GetExecCommand(context.Background(), "select 1")
+ if tt.wantErr {
+ require.Error(t, err)
+ return
+ }
+
+ require.NoError(t, err)
+ require.Equal(t, tt.cmd, got.Args)
+ })
+ }
+}
diff --git a/lib/utils/fs.go b/lib/utils/fs.go
index 0ddb0af13a1d9..6831a4a482b71 100644
--- a/lib/utils/fs.go
+++ b/lib/utils/fs.go
@@ -499,3 +499,13 @@ func RecursiveCopy(src, dest string, skip func(src, dest string) (bool, error))
return nil
}))
}
+
+// CreateExclusiveFile creates a file only if it does not exist to prevent overwriting
+// existing files.
+func CreateExclusiveFile(path string, mode os.FileMode) (*os.File, error) {
+ out, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, mode)
+ if err != nil {
+ return nil, trace.ConvertSystemError(err)
+ }
+ return out, nil
+}
diff --git a/lib/utils/log/slog.go b/lib/utils/log/slog.go
index bfb34f4a94114..67ea8c250dcba 100644
--- a/lib/utils/log/slog.go
+++ b/lib/utils/log/slog.go
@@ -21,8 +21,10 @@ package log
import (
"context"
"fmt"
+ "iter"
"log/slog"
"reflect"
+ "slices"
"strings"
"unicode"
@@ -201,3 +203,19 @@ func (a typeAttr) LogValue() slog.Value {
}
return slog.StringValue("nil")
}
+
+type iterAttr[V any] struct {
+ iter iter.Seq[V]
+}
+
+// IterAttr creates a [slog.LogValuer] that will defer the collection of an
+// iter.Seq.
+func IterAttr[V any](iter iter.Seq[V]) slog.LogValuer {
+ return iterAttr[V]{
+ iter: iter,
+ }
+}
+
+func (a iterAttr[V]) LogValue() slog.Value {
+ return slog.AnyValue(slices.Collect[V](a.iter))
+}
diff --git a/lib/utils/unpack.go b/lib/utils/unpack.go
index 35f76adbb452b..f11d2454e2f13 100644
--- a/lib/utils/unpack.go
+++ b/lib/utils/unpack.go
@@ -181,7 +181,7 @@ func writeFile(path string, r io.Reader, mode, dirMode os.FileMode) error {
err := withDir(path, dirMode, func() error {
// Create file only if it does not exist to prevent overwriting existing
// files (like session recordings).
- out, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, mode)
+ out, err := CreateExclusiveFile(path, mode)
if err != nil {
return trace.ConvertSystemError(err)
}
diff --git a/tool/common/common.go b/tool/common/common.go
index a7d9edfe7565f..d0f9ef2ae9664 100644
--- a/tool/common/common.go
+++ b/tool/common/common.go
@@ -206,7 +206,7 @@ func FormatMultiValueLabels(labels map[string][]string, verbose bool) string {
func FormatResourceName(r types.ResourceWithLabels, verbose bool) string {
if !verbose {
// return the (shorter) discovered name in non-verbose mode.
- discoveredName, ok := r.GetAllLabels()[types.DiscoveredNameLabel]
+ discoveredName, ok := GetDiscoveredResourceName(r)
if ok && discoveredName != "" {
return discoveredName
}
@@ -214,6 +214,21 @@ func FormatResourceName(r types.ResourceWithLabels, verbose bool) string {
return r.GetName()
}
+// GetDiscoveredResourceName returns the resource original name discovered in
+// the cloud by the Teleport Discovery Service.
+func GetDiscoveredResourceName(r types.ResourceWithLabels) (discoveredName string, ok bool) {
+ discoveredName, ok = r.GetAllLabels()[types.DiscoveredNameLabel]
+ return
+}
+
+// SetDiscoveredResourceName sets the original name discovered in the cloud by
+// the Teleport Discovery Service.
+func SetDiscoveredResourceName(r types.ResourceWithLabels, discoveredName string) {
+ labels := r.GetStaticLabels()
+ labels[types.DiscoveredNameLabel] = discoveredName
+ r.SetStaticLabels(labels)
+}
+
// FormatDefault formats a zero value with its default, or if the value is not
// zero it just returns the value.
func FormatDefault[T comparable](val, defaultVal T) string {
diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go
index ff5fe4073653a..b0ca1770eea57 100644
--- a/tool/tsh/common/db.go
+++ b/tool/tsh/common/db.go
@@ -652,8 +652,11 @@ func maybeStartLocalProxy(ctx context.Context, cf *CLIConf,
// certificate's DNS names. As such, connecting to 127.0.0.1 will fail
// validation, so connect to localhost.
host := "localhost"
- cmdOpts := []dbcmd.ConnectCommandFunc{
+ cmdOpts, err := makeDatabaseCommandOptions(ctx, tc, dbInfo,
dbcmd.WithLocalProxy(host, addr.Port(0), profile.CACertPathForCluster(rootClusterName)),
+ )
+ if err != nil {
+ return nil, trace.Wrap(err)
}
if requires.tunnel {
cmdOpts = append(cmdOpts, dbcmd.WithNoTLS())
@@ -783,18 +786,6 @@ func onDatabaseConnect(cf *CLIConf) error {
if err != nil {
return trace.Wrap(err)
}
- opts = append(opts,
- dbcmd.WithLogger(logger),
- dbcmd.WithGetDatabaseFunc(dbInfo.getDatabaseForDBCmd),
- )
-
- if opts, err = maybeAddDBUserPassword(cf, tc, dbInfo, opts); err != nil {
- return trace.Wrap(err)
- }
- if opts, err = maybeAddGCPMetadata(cf.Context, tc, dbInfo, opts); err != nil {
- return trace.Wrap(err)
- }
- opts = maybeAddOracleOptions(cf.Context, tc, dbInfo, opts)
bb := dbcmd.NewCmdBuilder(tc, profile, dbInfo.RouteToDatabase, rootClusterName, opts...)
cmd, err := bb.GetConnectCommand(cf.Context)
@@ -975,22 +966,7 @@ func (d *databaseInfo) checkAndSetDefaults(cf *CLIConf, tc *client.TeleportClien
return nil
}
- var clusterClient *client.ClusterClient
- err = client.RetryWithRelogin(cf.Context, tc, func() error {
- clusterClient, err = tc.ConnectToCluster(cf.Context)
- return trace.Wrap(err)
- })
- if err != nil {
- return trace.Wrap(err)
- }
- defer clusterClient.Close()
-
- profile, err := tc.ProfileStatus()
- if err != nil {
- return trace.Wrap(err)
- }
-
- checker, err := services.NewAccessCheckerForRemoteCluster(cf.Context, profile.AccessInfo(), tc.SiteName, clusterClient.AuthClient)
+ checker, err := d.getChecker(cf.Context, tc)
if err != nil {
return trace.Wrap(err)
}
@@ -1014,6 +990,30 @@ func (d *databaseInfo) checkAndSetDefaults(cf *CLIConf, tc *client.TeleportClien
return nil
}
+func (d *databaseInfo) getChecker(ctx context.Context, tc *client.TeleportClient) (services.AccessChecker, error) {
+ if d.checker != nil {
+ return d.checker, nil
+ }
+ var clusterClient *client.ClusterClient
+ var err error
+ err = client.RetryWithRelogin(ctx, tc, func() error {
+ clusterClient, err = tc.ConnectToCluster(ctx)
+ return trace.Wrap(err)
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ defer clusterClient.Close()
+
+ profile, err := tc.ProfileStatus()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ checker, err := services.NewAccessCheckerForRemoteCluster(ctx, profile.AccessInfo(), tc.SiteName, clusterClient.AuthClient)
+ return checker, trace.Wrap(err)
+}
+
// databaseInfo wraps a RouteToDatabase and the corresponding database.
// Its purpose is to prevent repeated fetches of the same database, by lazily
// fetching and caching the database for use as needed.
@@ -1024,6 +1024,7 @@ type databaseInfo struct {
database types.Database
// isActive indicates an active database matched this db info.
isActive bool
+ checker services.AccessChecker
mu sync.Mutex
}
diff --git a/tool/tsh/common/db_exec.go b/tool/tsh/common/db_exec.go
new file mode 100644
index 0000000000000..f7fabf02ff6b4
--- /dev/null
+++ b/tool/tsh/common/db_exec.go
@@ -0,0 +1,637 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package common
+
+import (
+ "context"
+ "crypto/tls"
+ "encoding/json"
+ "fmt"
+ "io"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "slices"
+ "strings"
+ "sync"
+
+ "github.com/gravitational/trace"
+ oteltrace "go.opentelemetry.io/otel/trace"
+ "golang.org/x/sync/errgroup"
+
+ "github.com/gravitational/teleport"
+ apiclient "github.com/gravitational/teleport/api/client"
+ "github.com/gravitational/teleport/api/client/proto"
+ apidefaults "github.com/gravitational/teleport/api/defaults"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/client"
+ "github.com/gravitational/teleport/lib/client/db/dbcmd"
+ "github.com/gravitational/teleport/lib/services"
+ "github.com/gravitational/teleport/lib/srv/alpnproxy"
+ "github.com/gravitational/teleport/lib/tlsca"
+ "github.com/gravitational/teleport/lib/utils"
+ logutils "github.com/gravitational/teleport/lib/utils/log"
+ "github.com/gravitational/teleport/tool/common"
+)
+
+func onDatabaseExec(cf *CLIConf) error {
+ execCommand, err := newDatabaseExecCommand(cf)
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ defer execCommand.close()
+
+ return trace.Wrap(execCommand.run())
+}
+
+// databaseExecClient is a wrapper of client.TeleportClient that makes backend
+// calls. It can be mocked for testing.
+type databaseExecClient interface {
+ close() error
+ getProfileStatus() *client.ProfileStatus
+ getAccessChecker() services.AccessChecker
+ issueCert(context.Context, *databaseInfo) (tls.Certificate, error)
+ listDatabasesWithFilter(context.Context, *proto.ListResourcesRequest) ([]types.Database, error)
+}
+
+type databaseExecCommand struct {
+ cf *CLIConf
+ tc *client.TeleportClient
+ client databaseExecClient
+ makeCommand func(context.Context, *databaseInfo, string, string) (*exec.Cmd, error)
+ summary databaseExecSummary
+}
+
+func newDatabaseExecCommand(cf *CLIConf) (*databaseExecCommand, error) {
+ if err := checkDatabaseExecInputFlags(cf); err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ tc, err := makeClient(cf)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ sharedClient, err := newSharedDatabaseExecClient(cf, tc)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ commandMaker, err := newDatabaseExecCommandMaker(cf.Context, tc, sharedClient.getProfileStatus())
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return &databaseExecCommand{
+ cf: cf,
+ tc: tc,
+ client: sharedClient,
+ makeCommand: commandMaker.makeCommand,
+ }, nil
+}
+
+func (c *databaseExecCommand) run() error {
+ // Fetch.
+ dbs, err := c.getDatabases()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ // Execute in parallel.
+ group, groupCtx := errgroup.WithContext(c.cf.Context)
+ group.SetLimit(c.cf.ParallelJobs)
+ for _, db := range dbs {
+ group.Go(func() error {
+ result := c.exec(groupCtx, db)
+ c.summary.add(result)
+ return nil
+ })
+ }
+
+ // Print summary.
+ defer func() {
+ switch {
+ case c.cf.OutputDir != "":
+ c.summary.printAndSave(c.cf.Stdout(), c.cf.OutputDir)
+ case len(dbs) > 1:
+ c.summary.print(c.cf.Stdout())
+ }
+ }()
+
+ return trace.Wrap(group.Wait())
+}
+
+func (c *databaseExecCommand) close() {
+ if err := c.client.close(); err != nil {
+ logger.WarnContext(c.cf.Context, "Failed to close client", "error", err)
+ }
+}
+
+func checkDatabaseExecInputFlags(cf *CLIConf) error {
+ // Pick an arbitrary number for max connections to avoid flooding the
+ // backend. The limit can be overwritten with the "TELEPORT_PARALLEL_JOBS"
+ // env var.
+ const maxParallelJobs = 10
+ if cf.ParallelJobs < 1 || cf.ParallelJobs > maxParallelJobs {
+ return trace.BadParameter(`--parallel must be between 1 and %v`, maxParallelJobs)
+ }
+
+ // Selection flags.
+ byNames := cf.DatabaseServices != ""
+ bySearch := cf.SearchKeywords != "" || cf.Labels != ""
+ switch {
+ case !byNames && !bySearch:
+ return trace.BadParameter("please provide one of --dbs, --labels, --search flags")
+ case byNames && bySearch:
+ return trace.BadParameter("--labels/--search flags cannot be used with --dbs flag")
+ }
+
+ // Logging.
+ if cf.ParallelJobs > 1 && cf.OutputDir == "" {
+ return trace.BadParameter("--output-dir must be set when executing concurrent connections")
+ }
+ if cf.OutputDir != "" && utils.FileExists(cf.OutputDir) {
+ return trace.BadParameter("directory %q already exists", cf.OutputDir)
+ }
+ return nil
+}
+
+func (c *databaseExecCommand) getDatabases() ([]types.Database, error) {
+ if c.cf.DatabaseServices != "" {
+ return c.getDatabasesByNames()
+ }
+ return c.searchDatabases()
+}
+
+func (c *databaseExecCommand) getDatabasesByNames() ([]types.Database, error) {
+ // Use a single predicate to search multiple names in one shot but batch 100
+ // names at a time. Extra validation will be performed afterward to ensure
+ // we fetched what we need.
+ fmt.Fprintln(c.cf.Stdout(), "Fetching databases by name ...")
+
+ // Trim spaces.
+ names := stringFlagToStrings(c.cf.DatabaseServices)
+
+ var dbs []types.Database
+ for page := range slices.Chunk(names, 100) {
+ var predicate string
+ for _, name := range page {
+ predicate = makePredicateDisjunction(predicate, makeDiscoveredNameOrNamePredicate(name))
+ }
+
+ logger.DebugContext(c.cf.Context, "Getting database by name", "databases", page)
+ pageDBs, err := c.client.listDatabasesWithFilter(c.cf.Context, &proto.ListResourcesRequest{
+ Namespace: apidefaults.Namespace,
+ ResourceType: types.KindDatabaseServer,
+ PredicateExpression: predicate,
+ })
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ dbs = append(dbs, pageDBs...)
+ }
+
+ logger.DebugContext(c.cf.Context, "Fetched databases by name",
+ "databases", logutils.IterAttr(types.ResourceNames(dbs)))
+ return dbs, trace.Wrap(ensureEachDatabase(names, dbs))
+}
+
+func (c *databaseExecCommand) searchDatabases() (databases []types.Database, err error) {
+ fmt.Fprintln(c.cf.Stdout(), "Searching databases ...")
+ filter := c.tc.ResourceFilter(types.KindDatabaseServer)
+
+ logger.DebugContext(c.cf.Context, "Searching for databases", "filter", filter)
+ dbs, err := c.client.listDatabasesWithFilter(c.cf.Context, filter)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ logger.DebugContext(c.cf.Context, "Fetched databases with search filter",
+ "databases", logutils.IterAttr(types.ResourceNames(dbs)),
+ )
+ return dbs, trace.Wrap(c.printSearchResultAndConfirm(dbs))
+}
+
+func (c *databaseExecCommand) printSearchResultAndConfirm(dbs []types.Database) error {
+ if len(dbs) == 0 {
+ return trace.NotFound("no databases found")
+ }
+
+ fmt.Fprintf(c.cf.Stdout(), "Found %d database(s):\n\n", len(dbs))
+ printTableForDatabaseExec(c.cf.Stdout(), dbs)
+ question := fmt.Sprintf("Do you want to proceed with %d database(s)?", len(dbs))
+ if err := c.cf.PromptConfirmation(question); err != nil {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+func (c *databaseExecCommand) exec(ctx context.Context, db types.Database) (result databaseExecResult) {
+ result = databaseExecResult{
+ RouteToDatabase: client.RouteToDatabaseToProto(c.makeRouteToDatabase(db)),
+ Command: c.cf.DatabaseCommand,
+ Success: true,
+ }
+
+ printErrorAndMakeErrorResult := func(err error) databaseExecResult {
+ fmt.Fprintf(c.cf.Stderr(), "Failed to execute command for %q: %v\n", db.GetName(), err)
+ result.Success = false
+ result.Error = err.Error()
+ return result
+ }
+
+ if ctx.Err() != nil {
+ return printErrorAndMakeErrorResult(ctx.Err())
+ }
+
+ outputWriter := c.cf.Stdout()
+ errWriter := c.cf.Stderr()
+ switch {
+ case c.cf.OutputDir != "":
+ // Use full-name instead of display name for output path.
+ logFile, err := c.openOutputFile(db.GetName())
+ if err != nil {
+ return printErrorAndMakeErrorResult(err)
+ }
+ defer logFile.Close()
+ outputWriter = logFile
+ errWriter = logFile
+ fmt.Fprintf(c.cf.Stdout(), "Executing command for %q. Output will be saved at %q.\n", db.GetName(), logFile.Name())
+
+ // Save absolute path in the summary. Not expecting the absolute check
+ // to fail but use the filename in case it does.
+ if result.OutputFile, err = filepath.Abs(logFile.Name()); err != nil {
+ result.OutputFile = filepath.Base(logFile.Name())
+ }
+ default:
+ // No prefix so output can still be copy-pasted. Extra empty line to
+ // separate sequential executions.
+ fmt.Fprintf(c.cf.Stdout(), "\nExecuting command for %q.\n", db.GetName())
+ }
+
+ var err error
+ result.ExitCode, err = c.runCommand(ctx, db, outputWriter, errWriter)
+ if err != nil {
+ return printErrorAndMakeErrorResult(err)
+ }
+ return result
+}
+
+func (c *databaseExecCommand) runCommand(ctx context.Context, db types.Database, outputWriter, errWriter io.Writer) (int, error) {
+ dbInfo, err := c.makeDatabaseInfo(db)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ lp, err := c.startLocalProxy(ctx, dbInfo)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ defer lp.Close()
+
+ dbCmd, err := c.makeCommand(ctx, dbInfo, lp.GetAddr(), c.cf.DatabaseCommand)
+ if err != nil {
+ return 0, trace.Wrap(err)
+ }
+ dbCmd.Stdout = outputWriter
+ dbCmd.Stderr = errWriter
+
+ logger.DebugContext(ctx, "Executing database command", "command", dbCmd, "db", dbInfo.ServiceName)
+ runErr := c.cf.RunCommand(dbCmd)
+ if dbCmd.ProcessState != nil {
+ return dbCmd.ProcessState.ExitCode(), trace.Wrap(runErr)
+ }
+ return 0, trace.Wrap(runErr)
+}
+
+func (c *databaseExecCommand) startLocalProxy(ctx context.Context, dbInfo *databaseInfo) (*alpnproxy.LocalProxy, error) {
+ // Issue single-use certificate.
+ clientCert, err := c.client.issueCert(ctx, dbInfo)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ // Do not provide a re-issuer middleware to the local proxy. The local proxy
+ // is meant for one-time use so there is no need to re-issue the
+ // certificates.
+ listener, err := createLocalProxyListener("localhost:0", dbInfo.RouteToDatabase, c.client.getProfileStatus())
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ opts := []alpnproxy.LocalProxyConfigOpt{
+ alpnproxy.WithDatabaseProtocol(dbInfo.Protocol),
+ alpnproxy.WithClusterCAsIfConnUpgrade(ctx, c.tc.RootClusterCACertPool),
+ alpnproxy.WithClientCert(clientCert),
+ }
+
+ lpConfig := makeBasicLocalProxyConfig(ctx, c.tc, listener, c.cf.InsecureSkipVerify)
+ lp, err := alpnproxy.NewLocalProxy(lpConfig, opts...)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ go func() {
+ defer listener.Close()
+ if err := lp.Start(ctx); err != nil {
+ logger.ErrorContext(ctx, "Failed to start local proxy", "error", err)
+ }
+ }()
+ return lp, nil
+}
+
+func (c *databaseExecCommand) makeRouteToDatabase(db types.Database) tlsca.RouteToDatabase {
+ return tlsca.RouteToDatabase{
+ ServiceName: db.GetName(),
+ Protocol: db.GetProtocol(),
+ Username: c.cf.DatabaseUser,
+ Database: c.cf.DatabaseName,
+ Roles: requestedDatabaseRoles(c.cf),
+ }
+}
+
+func (c *databaseExecCommand) makeDatabaseInfo(db types.Database) (*databaseInfo, error) {
+ dbInfo := &databaseInfo{
+ RouteToDatabase: c.makeRouteToDatabase(db),
+ database: db,
+ checker: c.client.getAccessChecker(),
+ }
+ return dbInfo, trace.Wrap(dbInfo.checkAndSetDefaults(c.cf, c.tc))
+}
+
+func (c *databaseExecCommand) outputFilename(dbServiceName string) string {
+ return filepath.Join(c.cf.OutputDir, dbServiceName+".output")
+}
+
+func (c *databaseExecCommand) openOutputFile(dbServiceName string) (*os.File, error) {
+ logFilePath, err := utils.EnsureLocalPath(c.outputFilename(dbServiceName), "", "")
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ logFile, err := utils.CreateExclusiveFile(logFilePath, teleport.FileMaskOwnerOnly)
+ return logFile, trace.ConvertSystemError(err)
+}
+
+// sharedDatabaseExecClient is a wrapper of client.TeleportClient that makes
+// backend calls while using a shared ClusterClient.
+type sharedDatabaseExecClient struct {
+ profile *client.ProfileStatus
+ clusterClient *client.ClusterClient
+ accessChecker services.AccessChecker
+ tracer oteltrace.Tracer
+}
+
+func newSharedDatabaseExecClient(cf *CLIConf, tc *client.TeleportClient) (*sharedDatabaseExecClient, error) {
+ var clusterClient *client.ClusterClient
+ var err error
+ if err := client.RetryWithRelogin(cf.Context, tc, func() error {
+ clusterClient, err = tc.ConnectToCluster(cf.Context)
+ return trace.Wrap(err)
+ }); err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ profile, err := tc.ProfileStatus()
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ accessChecker, err := services.NewAccessCheckerForRemoteCluster(cf.Context, profile.AccessInfo(), tc.SiteName, clusterClient.AuthClient)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+
+ return &sharedDatabaseExecClient{
+ profile: profile,
+ tracer: cf.TracingProvider.Tracer(teleport.ComponentTSH),
+ clusterClient: clusterClient,
+ accessChecker: accessChecker,
+ }, nil
+}
+
+func (c *sharedDatabaseExecClient) close() error {
+ if err := c.clusterClient.Close(); err != nil && !trace.IsConnectionProblem(err) {
+ return trace.Wrap(err)
+ }
+ return nil
+}
+
+func (c *sharedDatabaseExecClient) getAccessChecker() services.AccessChecker {
+ return c.accessChecker
+}
+
+func (c *sharedDatabaseExecClient) getProfileStatus() *client.ProfileStatus {
+ return c.profile
+}
+
+// issueCert issues a single use cert for the db route.
+func (c *sharedDatabaseExecClient) issueCert(ctx context.Context, dbInfo *databaseInfo) (tls.Certificate, error) {
+ // TODO(greedy52) add support for multi-session MFA.
+ params := client.ReissueParams{
+ RouteToDatabase: client.RouteToDatabaseToProto(dbInfo.RouteToDatabase),
+ AccessRequests: c.profile.ActiveRequests,
+ }
+
+ keyRing, _, err := c.clusterClient.IssueUserCertsWithMFA(ctx, params)
+ if err != nil {
+ return tls.Certificate{}, trace.Wrap(err)
+ }
+ dbCert, err := keyRing.DBTLSCert(dbInfo.RouteToDatabase.ServiceName)
+ return dbCert, trace.Wrap(err)
+}
+
+func (c *sharedDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, filter *proto.ListResourcesRequest) (databases []types.Database, err error) {
+ ctx, span := c.tracer.Start(
+ ctx,
+ "listDatabasesWithFilter",
+ oteltrace.WithSpanKind(oteltrace.SpanKindClient),
+ )
+ defer span.End()
+
+ servers, err := apiclient.GetAllResources[types.DatabaseServer](ctx, c.clusterClient.AuthClient, filter)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return types.DatabaseServers(servers).ToDatabases(), nil
+}
+
+type databaseExecCommandMaker struct {
+ tc *client.TeleportClient
+ profile *client.ProfileStatus
+ rootCluster string
+}
+
+func newDatabaseExecCommandMaker(ctx context.Context, tc *client.TeleportClient, profile *client.ProfileStatus) (*databaseExecCommandMaker, error) {
+ rootCluster, err := tc.RootClusterName(ctx)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return &databaseExecCommandMaker{
+ tc: tc,
+ profile: profile,
+ rootCluster: rootCluster,
+ }, nil
+}
+
+func (m *databaseExecCommandMaker) makeCommand(ctx context.Context, dbInfo *databaseInfo, lpAddr, command string) (*exec.Cmd, error) {
+ addr, err := utils.ParseAddr(lpAddr)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ opts, err := makeDatabaseCommandOptions(ctx, m.tc, dbInfo,
+ dbcmd.WithLocalProxy("localhost", addr.Port(0), ""),
+ dbcmd.WithNoTLS(),
+ )
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return dbcmd.NewCmdBuilder(m.tc, m.profile, dbInfo.RouteToDatabase, m.rootCluster, opts...).
+ GetExecCommand(ctx, command)
+}
+
+// ensureEachDatabase ensures a one-to-one mapping between the provided database
+// target names and database resources.
+//
+// Note that it is assumed that each provided database resource has at least one
+// matching name, as resources are retrieved from the backend based on one of
+// those names.
+func ensureEachDatabase(names []string, dbs []types.Database) error {
+ byDiscoveredNameOrName := map[string]types.Databases{}
+ for _, db := range dbs {
+ // Databases may be listed by their original name in the cloud.
+ byDiscoveredNameOrName[db.GetName()] = append(byDiscoveredNameOrName[db.GetName()], db)
+
+ if discoveredName, ok := common.GetDiscoveredResourceName(db); ok && discoveredName != db.GetName() {
+ byDiscoveredNameOrName[discoveredName] = append(byDiscoveredNameOrName[discoveredName], db)
+ }
+ }
+
+ for _, name := range names {
+ matched := byDiscoveredNameOrName[name]
+ switch len(matched) {
+ case 0:
+ return trace.NotFound("database %q not found", name)
+ case 1:
+ continue
+ default:
+ var sb strings.Builder
+ printTableForDatabaseExec(&sb, matched)
+ return trace.BadParameter(`%q matches multiple databases:
+%vTry selecting the database with a more specific name printed in the above table`, name, sb.String())
+ }
+ }
+
+ return nil
+}
+
+func printTableForDatabaseExec(w io.Writer, dbs []types.Database) {
+ rows := make([]databaseTableRow, 0, len(dbs))
+ for _, db := range dbs {
+ // Always use full name but don't print hidden labels.
+ row := getDatabaseRow("", "", "", db, nil, nil, false)
+ row.DisplayName = db.GetName()
+ rows = append(rows, row)
+ }
+ printDatabaseTable(printDatabaseTableConfig{
+ writer: w,
+ rows: rows,
+ includeColumns: []string{"Name", "Protocol", "Description", "Labels"},
+ })
+}
+
+type databaseExecResult struct {
+ proto.RouteToDatabase `json:"database"`
+ Command string `json:"command"`
+ OutputFile string `json:"output_file,omitempty"`
+ Success bool `json:"success"`
+ Error string `json:"error,omitempty"`
+ ExitCode int `json:"exit_code"`
+}
+
+type databaseExecSummary struct {
+ Databases []databaseExecResult `json:"databases"`
+ Success int `json:"success"`
+ Failure int `json:"failure"`
+ Total int `json:"total"`
+
+ mu sync.Mutex
+}
+
+func (s *databaseExecSummary) add(result databaseExecResult) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.Databases = append(s.Databases, result)
+ s.Total++
+ if result.Success {
+ s.Success++
+ } else {
+ s.Failure++
+ }
+}
+
+func (s *databaseExecSummary) print(w io.Writer) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ fmt.Fprintf(w, "\nSummary: %d of %d succeeded.\n", s.Success, s.Total)
+}
+
+func (s *databaseExecSummary) printAndSave(w io.Writer, outputDir string) {
+ s.print(w)
+ if err := s.save(w, outputDir); err != nil {
+ fmt.Fprintf(w, "Failed to save summary: %v\n", err)
+ }
+}
+
+func (s *databaseExecSummary) save(w io.Writer, outputDir string) error {
+ summaryPath := filepath.Join(outputDir, "summary.json")
+ summaryPath, err := utils.EnsureLocalPath(summaryPath, "", "")
+ if err != nil {
+ return trace.Wrap(err)
+ }
+ summaryFile, err := utils.CreateExclusiveFile(summaryPath, teleport.FileMaskOwnerOnly)
+ if trace.IsAlreadyExists(err) {
+ fmt.Fprintf(w, "Warning: file %s exists and will be overwritten.\n", summaryPath)
+ summaryFile, err = os.OpenFile(summaryPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, teleport.FileMaskOwnerOnly)
+ if err != nil {
+ return trace.ConvertSystemError(err)
+ }
+ } else if err != nil {
+ return trace.Wrap(err)
+ }
+ defer summaryFile.Close()
+
+ summaryData, err := s.makeSummaryJSONData()
+ if err != nil {
+ return trace.Wrap(err)
+ }
+
+ if _, err := summaryFile.Write(summaryData); err != nil {
+ return trace.ConvertSystemError(err)
+ }
+
+ fmt.Fprintf(w, "Summary is saved at %q.\n", summaryPath)
+ return nil
+}
+
+func (s *databaseExecSummary) makeSummaryJSONData() ([]byte, error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ summaryData, err := json.MarshalIndent(s, "", " ")
+ return summaryData, trace.Wrap(err)
+}
diff --git a/tool/tsh/common/db_exec_test.go b/tool/tsh/common/db_exec_test.go
new file mode 100644
index 0000000000000..dc99644fd1b04
--- /dev/null
+++ b/tool/tsh/common/db_exec_test.go
@@ -0,0 +1,525 @@
+/*
+ * Teleport
+ * Copyright (C) 2025 Gravitational, Inc.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <https://www.gnu.org/licenses/>.
+ */
+
+package common
+
+import (
+ "bytes"
+ "context"
+ "crypto/tls"
+ "fmt"
+ "os"
+ "os/exec"
+ "path/filepath"
+ "slices"
+ "strings"
+ "testing"
+
+ "github.com/gravitational/trace"
+ "github.com/stretchr/testify/require"
+
+ "github.com/gravitational/teleport"
+ "github.com/gravitational/teleport/api/client/proto"
+ "github.com/gravitational/teleport/api/types"
+ "github.com/gravitational/teleport/lib/client"
+ "github.com/gravitational/teleport/lib/fixtures"
+ "github.com/gravitational/teleport/lib/services"
+ "github.com/gravitational/teleport/lib/utils"
+ "github.com/gravitational/teleport/tool/common"
+)
+
+func Test_checkDatabaseExecInputFlags(t *testing.T) {
+ dir := t.TempDir()
+
+ tests := []struct {
+ name string
+ cf *CLIConf
+ checkError require.ErrorAssertionFunc
+ }{
+ {
+ name: "with database services",
+ cf: &CLIConf{
+ ParallelJobs: 1,
+ DatabaseServices: "db1,db2",
+ },
+ checkError: require.NoError,
+ },
+ {
+ name: "with search",
+ cf: &CLIConf{
+ ParallelJobs: 1,
+ SearchKeywords: "dev",
+ },
+ checkError: require.NoError,
+ },
+ {
+ name: "with labels",
+ cf: &CLIConf{
+ ParallelJobs: 1,
+ Labels: "env=dev",
+ },
+ checkError: require.NoError,
+ },
+ {
+ name: "invalid max connections",
+ cf: &CLIConf{
+ ParallelJobs: 15,
+ Labels: "env=dev",
+ },
+ checkError: require.Error,
+ },
+ {
+ name: "missing selection",
+ cf: &CLIConf{
+ ParallelJobs: 1,
+ },
+ checkError: require.Error,
+ },
+ {
+ name: "too many selection options",
+ cf: &CLIConf{
+ ParallelJobs: 1,
+ Labels: "env=dev",
+ DatabaseServices: "db1,db2",
+ },
+ checkError: require.Error,
+ },
+ {
+ name: "missing output dir",
+ cf: &CLIConf{
+ ParallelJobs: 5,
+ Labels: "env=dev",
+ },
+ checkError: require.Error,
+ },
+ {
+ name: "output dir exists",
+ cf: &CLIConf{
+ ParallelJobs: 5,
+ OutputDir: dir,
+ Labels: "env=dev",
+ },
+ checkError: require.Error,
+ },
+ {
+ name: "max connections and output dir",
+ cf: &CLIConf{
+ ParallelJobs: 5,
+ OutputDir: filepath.Join(dir, "output"),
+ Labels: "env=dev",
+ },
+ checkError: require.NoError,
+ },
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ tt.checkError(t, checkDatabaseExecInputFlags(tt.cf))
+ })
+ }
+}
+
+func TestDatabaseExec(t *testing.T) {
+ // Populate fake client.
+ cert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM))
+ require.NoError(t, err)
+
+ fakeClient := &fakeDatabaseExecClient{
+ cert: cert,
+ allDatabaseServers: []types.DatabaseServer{
+ mustMakeDatabaseServerForEnv(t, "pg1", types.DatabaseProtocolPostgreSQL, "dev"),
+ mustMakeDatabaseServerForEnv(t, "pg2", types.DatabaseProtocolPostgreSQL, "dev"),
+ mustMakeDatabaseServerForEnv(t, "pg3", types.DatabaseProtocolPostgreSQL, "prod"),
+ mustMakeDatabaseServerForEnv(t, "mysql", types.DatabaseProtocolMySQL, "prod"),
+ mustMakeDatabaseServerForEnv(t, "mongo", types.DatabaseProtocolMongoDB, "staging"),
+ },
+ }
+
+ // Commands are not actually being run but passed to cf.RunCommand.
+ // Here just passing the query through the command for verification.
+ dbQuery := "db-query"
+ makeCommand := func(_ context.Context, dbInfo *databaseInfo, _ string, dbQuery string) (*exec.Cmd, error) {
+ return exec.Command(dbQuery), nil
+ }
+ verifyDBQuery := func(cmd *exec.Cmd) error {
+ if !slices.Equal(cmd.Args, []string{dbQuery}) {
+ return trace.CompareFailed("expect %q but got %q", dbQuery, cmd.Args)
+ }
+ fmt.Fprintln(cmd.Stdout, dbQuery, "executed")
+ return nil
+ }
+
+ tests := []struct {
+ name string
+ setup func(*testing.T, *databaseExecCommand)
+ wantError string
+ expectOutputContains []string
+ verifyDir func(t *testing.T, dir string)
+ }{
+ {
+ name: "no databases found by search",
+ setup: func(_ *testing.T, cmd *databaseExecCommand) {
+ cmd.cf.SearchKeywords = "not-found"
+ },
+ wantError: "no databases found",
+ },
+ {
+ name: "no databases found by names",
+ setup: func(_ *testing.T, cmd *databaseExecCommand) {
+ cmd.cf.DatabaseServices = "not-found"
+ },
+ wantError: "not found",
+ },
+ {
+ name: "by names",
+ setup: func(_ *testing.T, cmd *databaseExecCommand) {
+ cmd.cf.DatabaseServices = "pg1,pg2,pg3"
+ },
+ expectOutputContains: []string{
+ "Fetching databases by name",
+ "Executing command for \"pg1\".",
+ "db-query executed",
+ "Executing command for \"pg2\".",
+ "Executing command for \"pg3\".",
+ "Summary:",
+ },
+ },
+ {
+ name: "by keyword",
+ setup: func(_ *testing.T, cmd *databaseExecCommand) {
+ cmd.cf.SearchKeywords = "mysql"
+ },
+ expectOutputContains: []string{
+ "Found 1 database(s)",
+ "Name Description Protocol Labels",
+ "----- ----------- -------- --------",
+ "mysql mysql env=prod",
+ "Executing command for \"mysql\".",
+ "db-query executed",
+ },
+ },
+ {
+ name: "by env",
+ setup: func(_ *testing.T, cmd *databaseExecCommand) {
+ cmd.cf.Labels = "env=dev"
+ },
+ expectOutputContains: []string{
+ "Found 2 database(s)",
+ "Name Description Protocol Labels",
+ "---- ----------- -------- -------",
+ "pg1 postgres env=dev",
+ "pg2 postgres env=dev",
+ "Executing command for \"pg1\".",
+ "db-query executed",
+ "Executing command for \"pg2\".",
+ "Summary:",
+ },
+ },
+ {
+ name: "output dir",
+ setup: func(_ *testing.T, cmd *databaseExecCommand) {
+ cmd.cf.DatabaseServices = "pg3,mysql"
+ cmd.cf.OutputDir = filepath.Join(cmd.cf.HomePath, "test-output")
+ },
+ expectOutputContains: []string{
+ "Fetching databases by name",
+ "Executing command for \"pg3\". Output will be saved at",
+ "Executing command for \"mysql\". Output will be saved at",
+ "Summary:",
+ "Summary is saved",
+ },
+ verifyDir: func(t *testing.T, dir string) {
+ t.Helper()
+ read, err := utils.ReadPath(filepath.Join(dir, "test-output", "pg3.output"))
+ require.NoError(t, err)
+ require.Equal(t, "db-query executed", strings.TrimSpace(string(read)))
+ read, err = utils.ReadPath(filepath.Join(dir, "test-output", "mysql.output"))
+ require.NoError(t, err)
+ require.Equal(t, "db-query executed", strings.TrimSpace(string(read)))
+ require.True(t, utils.FileExists(filepath.Join(dir, "test-output", "summary.json")))
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ dir := t.TempDir()
+
+ // Prep CLIConf.
+ var capture bytes.Buffer
+ writer := utils.NewSyncWriter(&capture)
+ cf := &CLIConf{
+ Proxy: "proxy:3080",
+ Context: context.Background(),
+ HomePath: dir,
+ ParallelJobs: 1,
+ DatabaseUser: "db-user",
+ DatabaseName: "db-name",
+ DatabaseCommand: dbQuery,
+ cmdRunner: verifyDBQuery,
+ OverrideStdout: writer,
+ overrideStderr: writer,
+ Confirm: false,
+ }
+
+ // Prep command and sanity check.
+ c := &databaseExecCommand{
+ cf: cf,
+ client: fakeClient,
+ makeCommand: makeCommand,
+ }
+ tt.setup(t, c)
+ mustCreateEmptyProfile(t, cf)
+ c.tc, err = makeClient(cf)
+ require.NoError(t, err)
+ require.NoError(t, checkDatabaseExecInputFlags(c.cf))
+
+ runError := c.run()
+ if tt.wantError != "" {
+ require.Error(t, runError)
+ require.Contains(t, runError.Error(), tt.wantError)
+ return
+ }
+
+ output := capture.String()
+ for _, expect := range tt.expectOutputContains {
+ require.Contains(t, output, expect)
+ }
+
+ if tt.verifyDir != nil {
+ tt.verifyDir(t, dir)
+ }
+ })
+ }
+
+}
+
+type fakeDatabaseExecClient struct {
+ cert tls.Certificate
+ allDatabaseServers []types.DatabaseServer
+}
+
+func (c *fakeDatabaseExecClient) close() error {
+ return nil
+}
+func (c *fakeDatabaseExecClient) getProfileStatus() *client.ProfileStatus {
+ return &client.ProfileStatus{}
+}
+func (c *fakeDatabaseExecClient) getAccessChecker() services.AccessChecker {
+ return services.NewAccessCheckerWithRoleSet(&services.AccessInfo{}, "clustername", services.NewRoleSet())
+}
+func (c *fakeDatabaseExecClient) issueCert(context.Context, *databaseInfo) (tls.Certificate, error) {
+ return c.cert, nil
+}
+func (c *fakeDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, req *proto.ListResourcesRequest) ([]types.Database, error) {
+ filter := services.MatchResourceFilter{
+ ResourceKind: req.ResourceType,
+ Labels: req.Labels,
+ SearchKeywords: req.SearchKeywords,
+ }
+ if req.PredicateExpression != "" {
+ expression, err := services.NewResourceExpression(req.PredicateExpression)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ }
+ filter.PredicateExpression = expression
+ }
+
+ var filtered []types.Database
+ for _, dbServer := range c.allDatabaseServers {
+ match, err := services.MatchResourceByFilters(dbServer, filter, nil)
+ if err != nil {
+ return nil, trace.Wrap(err)
+ } else if match {
+ filtered = append(filtered, dbServer.GetDatabase())
+ }
+ }
+ return filtered, nil
+}
+
+func mustMakeDatabaseServer(t *testing.T, db types.Database) types.DatabaseServer {
+ t.Helper()
+
+ dbV3, ok := db.(*types.DatabaseV3)
+ require.True(t, ok)
+
+ server, err := types.NewDatabaseServerV3(types.Metadata{
+ Name: db.GetName(),
+ }, types.DatabaseServerSpecV3{
+ Version: teleport.Version,
+ Hostname: "hostname",
+ HostID: "host-id",
+ Database: dbV3,
+ ProxyIDs: []string{"proxy"},
+ })
+ require.NoError(t, err)
+ return server
+}
+
+func mustMakeDatabaseForEnv(t *testing.T, name, protocol, env string) types.Database {
+ t.Helper()
+ db, err := types.NewDatabaseV3(
+ types.Metadata{
+ Name: name,
+ Labels: map[string]string{"env": env},
+ },
+ types.DatabaseSpecV3{
+ Protocol: protocol,
+ URI: "localhost:12345",
+ },
+ )
+ require.NoError(t, err)
+ return db
+}
+
+func mustMakeDatabaseServerForEnv(t *testing.T, name, protocol, env string) types.DatabaseServer {
+ t.Helper()
+ db := mustMakeDatabaseForEnv(t, name, protocol, env)
+ return mustMakeDatabaseServer(t, db)
+}
+
+func Test_ensureEachDatabase(t *testing.T) {
+ devDB := mustMakeDatabaseForEnv(t, "dev", "postgres", "dev")
+ stagingDB := mustMakeDatabaseForEnv(t, "staging", "postgres", "staging")
+ prodDB1 := mustMakeDatabaseForEnv(t, "prod", "postgres", "prod")
+ prodDB2 := mustMakeDatabaseForEnv(t, "prod", "postgres", "prod")
+ common.SetDiscoveredResourceName(stagingDB, "staging") // edge case where discovered name is the same.
+ common.SetDiscoveredResourceName(prodDB1, "prod-cloud1")
+ common.SetDiscoveredResourceName(prodDB2, "prod-cloud2")
+
+ tests := []struct {
+ name string
+ inputNames []string
+ inputDatabases []types.Database
+ expectErrorContains string
+ }{
+ {
+ name: "exact match",
+ inputNames: []string{"dev", "staging"},
+ inputDatabases: []types.Database{devDB, stagingDB},
+ },
+ {
+ name: "discovered name match",
+ inputNames: []string{"prod-cloud1", "prod-cloud2", "dev"},
+ inputDatabases: []types.Database{devDB, prodDB1, prodDB2},
+ },
+ {
+ name: "database not found",
+ inputNames: []string{"dev", "staging", "prod-cloud5"},
+ inputDatabases: []types.Database{devDB, stagingDB},
+ expectErrorContains: "\"prod-cloud5\" not found",
+ },
+ {
+ name: "ambiguous name",
+ inputNames: []string{"prod"},
+ inputDatabases: []types.Database{prodDB1, prodDB2},
+ expectErrorContains: "\"prod\" matches multiple databases",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ err := ensureEachDatabase(tt.inputNames, tt.inputDatabases)
+ if tt.expectErrorContains == "" {
+ require.NoError(t, err)
+ } else {
+ require.Error(t, err)
+ require.Contains(t, err.Error(), tt.expectErrorContains)
+ }
+ })
+ }
+}
+
+func Test_databaseExecSummary(t *testing.T) {
+ summary := databaseExecSummary{}
+ summary.add(databaseExecResult{
+ RouteToDatabase: proto.RouteToDatabase{
+ ServiceName: "db1",
+ Protocol: "postgres",
+ Username: "db-user",
+ },
+ Success: true,
+ })
+ summary.add(databaseExecResult{
+ RouteToDatabase: proto.RouteToDatabase{
+ ServiceName: "db2",
+ Protocol: "postgres",
+ Username: "db-user",
+ },
+ Error: "some error",
+ ExitCode: 1,
+ })
+ summary.add(databaseExecResult{
+ RouteToDatabase: proto.RouteToDatabase{
+ ServiceName: "db3",
+ Protocol: "postgres",
+ Username: "db-user",
+ },
+ Success: true,
+ })
+
+ var buf bytes.Buffer
+ summary.print(&buf)
+ require.Contains(t, buf.String(), "Summary: 2 of 3 succeeded")
+
+ buf.Reset()
+ dir := t.TempDir()
+ expectPath := filepath.Join(dir, "summary.json")
+ summary.printAndSave(&buf, dir)
+ require.Contains(t, buf.String(), "Summary: 2 of 3 succeeded")
+ require.Contains(t, buf.String(), fmt.Sprintf("Summary is saved at %q", expectPath))
+ summaryData, err := os.ReadFile(expectPath)
+ require.NoError(t, err)
+ require.Equal(t, `{
+ "databases": [
+ {
+ "database": {
+ "service_name": "db1",
+ "protocol": "postgres",
+ "username": "db-user"
+ },
+ "command": "",
+ "success": true,
+ "exit_code": 0
+ },
+ {
+ "database": {
+ "service_name": "db2",
+ "protocol": "postgres",
+ "username": "db-user"
+ },
+ "command": "",
+ "success": false,
+ "error": "some error",
+ "exit_code": 1
+ },
+ {
+ "database": {
+ "service_name": "db3",
+ "protocol": "postgres",
+ "username": "db-user"
+ },
+ "command": "",
+ "success": true,
+ "exit_code": 0
+ }
+ ],
+ "success": 2,
+ "failure": 1,
+ "total": 3
+}`, string(summaryData))
+}
diff --git a/tool/tsh/common/db_print.go b/tool/tsh/common/db_print.go
index c79a44d5dec64..b86e856584fba 100644
--- a/tool/tsh/common/db_print.go
+++ b/tool/tsh/common/db_print.go
@@ -79,9 +79,20 @@ type printDatabaseTableConfig struct {
rows []databaseTableRow
showProxyAndCluster bool
verbose bool
+ // includeColumns specifies a whitelist of columns to include. verbose and
+ // showProxyAndCluster are ignored when includeColumns is provided.
+ includeColumns []string
}
-func (cfg printDatabaseTableConfig) excludeColumns() (out []string) {
+func (cfg printDatabaseTableConfig) excludeColumns(allColumns []string) (out []string) {
+ if len(cfg.includeColumns) > 0 {
+ for _, column := range allColumns {
+ if !slices.Contains(cfg.includeColumns, column) {
+ out = append(out, column)
+ }
+ }
+ return
+ }
if !cfg.showProxyAndCluster {
out = append(out, "Proxy", "Cluster")
}
@@ -94,7 +105,7 @@ func (cfg printDatabaseTableConfig) excludeColumns() (out []string) {
func printDatabaseTable(cfg printDatabaseTableConfig) {
allColumns := makeTableColumnTitles(databaseTableRow{})
rowsWithAllColumns := makeTableRows(cfg.rows)
- excludeColumns := cfg.excludeColumns()
+ excludeColumns := cfg.excludeColumns(allColumns)
var printColumns []string
printRows := make([][]string, len(cfg.rows))
diff --git a/tool/tsh/common/db_print_test.go b/tool/tsh/common/db_print_test.go
index 23593cb6f2b48..22ca4b6a9a8f8 100644
--- a/tool/tsh/common/db_print_test.go
+++ b/tool/tsh/common/db_print_test.go
@@ -106,6 +106,19 @@ db2 describe db2 mysql self-hosted localhost:3306 [alice] [readonly]
proxy cluster1 db1 describe db1 postgres self-hosted localhost:5432 [*] Env=dev tsh db connect db1
proxy cluster1 db2 describe db2 mysql self-hosted localhost:3306 [alice] [readonly] Env=prod
+`,
+ },
+ {
+ name: "tsh db exec search results",
+ cfg: printDatabaseTableConfig{
+ rows: rows,
+ includeColumns: []string{"Name", "Protocol", "Description", "Labels"},
+ },
+ expect: `Name Description Protocol Labels
+---- ------------ -------- --------
+db1 describe db1 postgres Env=dev
+db2 describe db2 mysql Env=prod
+
`,
},
}
diff --git a/tool/tsh/common/git_list_test.go b/tool/tsh/common/git_list_test.go
index cf4d52e318043..15a40a625e94f 100644
--- a/tool/tsh/common/git_list_test.go
+++ b/tool/tsh/common/git_list_test.go
@@ -128,21 +128,14 @@ func TestGitListCommand(t *testing.T) {
}
// Create a empty profile so we don't ping proxy.
- clientStore, err := initClientStore(cf, cf.Proxy)
- require.NoError(t, err)
- profile := &profile.Profile{
- SSHProxyAddr: "proxy:3023",
- WebProxyAddr: "proxy:3080",
- }
- err = clientStore.SaveProfile(profile, true)
- require.NoError(t, err)
+ mustCreateEmptyProfile(t, cf)
cmd := gitListCommand{
format: test.format,
fetchFn: test.fetchFn,
}
- err = cmd.run(cf)
+ err := cmd.run(cf)
if test.wantError {
require.Error(t, err)
} else {
@@ -154,3 +147,21 @@ func TestGitListCommand(t *testing.T) {
})
}
}
+
+// mustCreateEmptyProfile creates an empty profile so we don't ping proxy when
+// calling makeClient on provided cf.
+func mustCreateEmptyProfile(t *testing.T, cf *CLIConf) {
+ t.Helper()
+
+ if cf.HomePath == "" {
+ cf.HomePath = t.TempDir()
+ }
+
+ clientStore, err := initClientStore(cf, cf.Proxy)
+ require.NoError(t, err)
+ err = clientStore.SaveProfile(&profile.Profile{
+ SSHProxyAddr: cf.Proxy,
+ WebProxyAddr: cf.Proxy,
+ }, true)
+ require.NoError(t, err)
+}
diff --git a/tool/tsh/common/help.go b/tool/tsh/common/help.go
index 427a3be93702d..71ac21769b1db 100644
--- a/tool/tsh/common/help.go
+++ b/tool/tsh/common/help.go
@@ -62,4 +62,18 @@ Examples:
Get database names using "jq":
$ tsh db ls --format json | jq -r '.[].metadata.name'`
+
+ dbExecHelp = `
+Examples:
+ Search databases with labels:
+ $ tsh db exec "source my_script.sql" --db-user mysql --labels key1=value1,key2=value2
+
+ Search databases with keywords:
+ $ tsh db exec "select 1" --db-user mysql --db-name mysql --search foo,bar
+
+ Execute a command on specified target databases without confirmation:
+ $ tsh db exec "select @@hostname" --db-user mysql --dbs mydb1,mydb2,mydb3 --no-confirm
+
+ Run commands in parallel, and save outputs to files:
+ $ tsh db exec "select 1" --db-user mysql --labels env=dev --parallel=5 --output-dir=exec-outputs`
)
diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go
index 45a112a254ad3..f1faf9bc57fae 100644
--- a/tool/tsh/common/proxy.go
+++ b/tool/tsh/common/proxy.go
@@ -216,21 +216,15 @@ func onProxyCommandDB(cf *CLIConf) error {
if err != nil {
return trace.Wrap(err)
}
- opts := []dbcmd.ConnectCommandFunc{
+ opts, err := makeDatabaseCommandOptions(cf.Context, tc, dbInfo,
dbcmd.WithLocalProxy("localhost", addr.Port(0), ""),
dbcmd.WithNoTLS(),
- dbcmd.WithLogger(logger),
dbcmd.WithPrintFormat(),
dbcmd.WithTolerateMissingCLIClient(),
- dbcmd.WithGetDatabaseFunc(dbInfo.getDatabaseForDBCmd),
- }
- if opts, err = maybeAddDBUserPassword(cf, tc, dbInfo, opts); err != nil {
- return trace.Wrap(err)
- }
- if opts, err = maybeAddGCPMetadata(cf.Context, tc, dbInfo, opts); err != nil {
+ )
+ if err != nil {
return trace.Wrap(err)
}
- opts = maybeAddOracleOptions(cf.Context, tc, dbInfo, opts)
commands, err := dbcmd.NewCmdBuilder(tc, profile, dbInfo.RouteToDatabase, rootCluster,
opts...,
@@ -278,9 +272,9 @@ func onProxyCommandDB(cf *CLIConf) error {
return nil
}
-func maybeAddDBUserPassword(cf *CLIConf, tc *libclient.TeleportClient, dbInfo *databaseInfo, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) {
+func maybeAddDBUserPassword(ctx context.Context, tc *libclient.TeleportClient, dbInfo *databaseInfo, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) {
if dbInfo.Protocol == defaults.ProtocolCassandra {
- db, err := dbInfo.GetDatabase(cf.Context, tc)
+ db, err := dbInfo.GetDatabase(ctx, tc)
if err != nil {
return nil, trace.Wrap(err)
}
@@ -302,6 +296,22 @@ func requiresGCPMetadata(protocol string) bool {
return protocol == defaults.ProtocolSpanner
}
+func makeDatabaseCommandOptions(ctx context.Context, tc *libclient.TeleportClient, dbInfo *databaseInfo, extraOpts ...dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) {
+ var err error
+ opts := append([]dbcmd.ConnectCommandFunc{
+ dbcmd.WithLogger(logger),
+ dbcmd.WithGetDatabaseFunc(dbInfo.getDatabaseForDBCmd),
+ }, extraOpts...)
+
+ if opts, err = maybeAddDBUserPassword(ctx, tc, dbInfo, opts); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ if opts, err = maybeAddGCPMetadata(ctx, tc, dbInfo, opts); err != nil {
+ return nil, trace.Wrap(err)
+ }
+ return maybeAddOracleOptions(ctx, tc, dbInfo, opts), nil
+}
+
func maybeAddGCPMetadata(ctx context.Context, tc *libclient.TeleportClient, dbInfo *databaseInfo, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) {
if !requiresGCPMetadata(dbInfo.Protocol) {
return opts, nil
diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go
index 5dd1a298259e2..1107984f777f0 100644
--- a/tool/tsh/common/tsh.go
+++ b/tool/tsh/common/tsh.go
@@ -67,6 +67,7 @@ import (
"github.com/gravitational/teleport/api/types/accesslist"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/types/wrappers"
+ apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/keys/hardwarekey"
"github.com/gravitational/teleport/api/utils/keys/piv"
"github.com/gravitational/teleport/api/utils/prompt"
@@ -241,12 +242,17 @@ type CLIConf struct {
// DatabaseService specifies the database proxy server to log into.
DatabaseService string
+ // DatabaseServices specifies a list of database services.
+ DatabaseServices string
// DatabaseUser specifies database user to embed in the certificate.
DatabaseUser string
// DatabaseName specifies database name to embed in the certificate.
DatabaseName string
// DatabaseRoles specifies database roles to embed in the certificate.
DatabaseRoles string
+ // DatabaseCommand specifies the command to execute.
+ DatabaseCommand string
+
// AppName specifies proxied application name.
AppName string
// Interactive sessions will allocate a PTY and create interactive "shell"
@@ -586,6 +592,13 @@ type CLIConf struct {
// to the hardware key agent. Some commands, like login, are better with the
// direct PIV service so that prompts are not split between processes.
disableHardwareKeyAgentClient bool
+
+ // ParallelJobs specifies the number of parallel jobs allowed.
+ ParallelJobs int
+ // OutputDir specifies the directory for storing command outputs.
+ OutputDir string
+ // Confirm determines whether to provide a y/N confirmation prompt.
+ Confirm bool
}
// Stdout returns the stdout writer.
@@ -633,6 +646,23 @@ func (c *CLIConf) LookPath(file string) (string, error) {
return exec.LookPath(file)
}
+// PromptConfirmation prompts the user for a yes/no confirmation for question.
+// The prompt is skipped unless cf.Confirm is set.
+func (c *CLIConf) PromptConfirmation(question string) error {
+ if !c.Confirm {
+ fmt.Fprintf(c.Stdout(), "Skipping confirmation for %q due to the --no-confirm flag.\n", question)
+ return nil
+ }
+
+ ok, err := prompt.Confirmation(c.Context, c.Stdout(), prompt.Stdin(), question)
+ if err != nil {
+ return trace.Wrap(err)
+ } else if !ok {
+ return trace.Errorf("Operation canceled by user request.")
+ }
+ return nil
+}
+
func Main() {
cmdLineOrig := os.Args[1:]
var cmdLine []string
@@ -1030,6 +1060,18 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
dbConnect.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason)
dbConnect.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest)
dbConnect.Flag("tunnel", "Open authenticated tunnel using database's client certificate so clients don't need to authenticate").Hidden().BoolVar(&cf.LocalProxyTunnel)
+ dbExec := db.Command("exec", "Execute database commands on target database services.")
+ dbExec.Flag("db-user", "Database user to log in as.").Short('u').StringVar(&cf.DatabaseUser)
+ dbExec.Flag("db-name", "Database name to log in to.").Short('n').StringVar(&cf.DatabaseName)
+ dbExec.Flag("db-roles", "List of comma separated database roles to use for auto-provisioned user.").Short('r').StringVar(&cf.DatabaseRoles)
+ dbExec.Flag("search", searchHelp).StringVar(&cf.SearchKeywords)
+ dbExec.Flag("labels", labelHelp).StringVar(&cf.Labels)
+ dbExec.Flag("parallel", "Run commands on target databases in parallel. Defaults to 1, and maximum allowed is 10.").Default("1").IntVar(&cf.ParallelJobs)
+ dbExec.Flag("output-dir", "Directory to store command output per target database service. A summary is saved as \"summary.json\".").StringVar(&cf.OutputDir)
+ dbExec.Flag("dbs", "List of comma separated target database services. Mutually exclusive with --search or --labels.").StringVar(&cf.DatabaseServices)
+ dbExec.Flag("confirm", "Confirm selected database services before executing command.").Default("true").BoolVar(&cf.Confirm)
+ dbExec.Arg("command", "Execute this command on target database services.").Required().StringVar(&cf.DatabaseCommand)
+ dbExec.Alias(dbExecHelp)
// join
join := app.Command("join", "Join the active SSH or Kubernetes session.")
@@ -1592,6 +1634,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error {
err = onDatabaseConfig(&cf)
case dbConnect.FullCommand():
err = onDatabaseConnect(&cf)
+ case dbExec.FullCommand():
+ err = onDatabaseExec(&cf)
case environment.FullCommand():
err = onEnvironment(&cf)
case mfa.ls.FullCommand():
@@ -5867,3 +5911,14 @@ func tryLockMemory(cf *CLIConf) error {
return trace.BadParameter("unexpected value for --mlock, expected one of (%v)", strings.Join(mlockModes, ", "))
}
}
+
+// stringFlagToStrings parses a comma-separated string from a CLIConf flag into
+// a slice of strings. It trims whitespace from each value and removes
+// duplicates.
+func stringFlagToStrings(value string) []string {
+ values := strings.Split(value, ",")
+ for i := range values {
+ values[i] = strings.TrimSpace(values[i])
+ }
+ return apiutils.Deduplicate(values)
+}