diff --git a/api/types/resource.go b/api/types/resource.go index bf5132cb56a99..544f7b1f92b8c 100644 --- a/api/types/resource.go +++ b/api/types/resource.go @@ -17,6 +17,7 @@ limitations under the License. package types import ( + "iter" "regexp" "slices" "sort" @@ -30,6 +31,7 @@ import ( "github.com/gravitational/teleport/api/types/common" "github.com/gravitational/teleport/api/types/compare" "github.com/gravitational/teleport/api/utils" + "github.com/gravitational/teleport/api/utils/iterutils" ) var ( @@ -84,6 +86,12 @@ func GetName[R Resource](r R) string { return r.GetName() } +// ResourceNames creates an iterator that loops through the provided slice of +// resources and return their names. +func ResourceNames[R Resource, S ~[]R](s S) iter.Seq[string] { + return iterutils.Map(GetName, slices.Values(s)) +} + // ResourceDetails includes details about the resource type ResourceDetails struct { Hostname string diff --git a/api/types/resource_test.go b/api/types/resource_test.go index 53b38ef33e145..f41da33f98c90 100644 --- a/api/types/resource_test.go +++ b/api/types/resource_test.go @@ -17,6 +17,8 @@ package types import ( + "fmt" + "slices" "testing" "time" @@ -820,3 +822,20 @@ func TestResourceHeaderIsEqual(t *testing.T) { }) } } + +func TestResourceNames(t *testing.T) { + var apps Apps + var expectedNames []string + for i := 0; i < 10; i++ { + app, err := NewAppV3(Metadata{ + Name: fmt.Sprintf("app-%d", i), + }, AppSpecV3{ + URI: "tcp://localhost:1111", + }) + require.NoError(t, err) + apps = append(apps, app) + expectedNames = append(expectedNames, app.GetName()) + } + + require.Equal(t, expectedNames, slices.Collect(ResourceNames(apps))) +} diff --git a/api/utils/iterutils/iter.go b/api/utils/iterutils/iter.go new file mode 100644 index 0000000000000..3d0ddf495f43c --- /dev/null +++ b/api/utils/iterutils/iter.go @@ -0,0 +1,37 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package iterutils + +import ( + "iter" +) + +// Map returns an iterator over f applied to seq. +// +// Copied from https://github.com/golang/go/issues/61898. We should switch to an +// official package once it is available. +func Map[In, Out any](f func(In) Out, seq iter.Seq[In]) iter.Seq[Out] { + return func(yield func(Out) bool) { + for in := range seq { + if !yield(f(in)) { + return + } + } + } +} diff --git a/api/utils/iterutils/iter_test.go b/api/utils/iterutils/iter_test.go new file mode 100644 index 0000000000000..e6bf3e3feecd2 --- /dev/null +++ b/api/utils/iterutils/iter_test.go @@ -0,0 +1,39 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package iterutils + +import ( + "fmt" + "slices" + "strings" +) + +func ExampleMap() { + inputs := []string{ + "hello world", + "foo", + } + + for mapped := range Map(strings.ToUpper, slices.Values(inputs)) { + fmt.Println(mapped) + } + // Output: + // HELLO WORLD + // FOO +} diff --git a/lib/client/db/dbcmd/exec.go b/lib/client/db/dbcmd/exec.go new file mode 100644 index 0000000000000..e31f25110ca7c --- /dev/null +++ b/lib/client/db/dbcmd/exec.go @@ -0,0 +1,60 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package dbcmd + +import ( + "context" + "os/exec" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/lib/defaults" +) + +// GetExecCommand returns a command that executes the provided query on the +// target database using an appropriate CLI database client. +func (c *CLICommandBuilder) GetExecCommand(_ context.Context, query string) (*exec.Cmd, error) { + if !c.options.noTLS || c.options.localProxyHost == "" { + return nil, trace.BadParameter("query execution is only supported when using an authenticated local proxy") + } + + switch c.db.Protocol { + case defaults.ProtocolPostgres: + return c.getPostgresExecCommand(query) + case defaults.ProtocolMySQL: + return c.getMySQLExecCommand(query) + default: + return nil, trace.BadParameter("%s databases not supported for exec command", c.db.Protocol) + } +} + +func (c *CLICommandBuilder) getPostgresExecCommand(query string) (*exec.Cmd, error) { + cmd := c.getPostgresCommand() + cmd.Args = append(cmd.Args, "-c", query) + return cmd, nil +} + +func (c *CLICommandBuilder) getMySQLExecCommand(query string) (*exec.Cmd, error) { + cmd, err := c.getMySQLCommand() + if err != nil { + return nil, trace.Wrap(err) + } + cmd.Args = append(cmd.Args, "-e", query) + return cmd, nil +} diff --git a/lib/client/db/dbcmd/exec_test.go b/lib/client/db/dbcmd/exec_test.go new file mode 100644 index 0000000000000..03f8743ef7180 --- /dev/null +++ b/lib/client/db/dbcmd/exec_test.go @@ -0,0 +1,120 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package dbcmd + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/observability/tracing" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" +) + +func TestCLICommandBuilderGetExecCommand(t *testing.T) { + fakeExec := &fakeExec{ + execOutput: map[string][]byte{ + "psql": []byte(""), + "mysql": []byte(""), + }, + } + + conf := &client.Config{ + HomePath: t.TempDir(), + Host: "localhost", + WebProxyAddr: "proxy.example.com", + SiteName: "db.example.com", + Tracer: tracing.NoopProvider().Tracer("test"), + } + + tc, err := client.NewClient(conf) + require.NoError(t, err) + + profile := &client.ProfileStatus{ + Name: "example.com", + Username: "bob", + Dir: "/tmp", + Cluster: "example.com", + } + + tests := []struct { + name string + opts []ConnectCommandFunc + protocol string + cmd []string + wantErr bool + }{ + { + name: "not authenticated tunnel", + protocol: defaults.ProtocolPostgres, + wantErr: true, + }, + { + name: "unsupported protocol", + protocol: defaults.ProtocolDynamoDB, + opts: []ConnectCommandFunc{WithNoTLS()}, + wantErr: true, + }, + { + name: "postgres", + protocol: defaults.ProtocolPostgres, + opts: []ConnectCommandFunc{WithNoTLS()}, + cmd: []string{"psql", "postgres://db-user@localhost:12345/db-name", "-c", "select 1"}, + wantErr: false, + }, + { + name: "mysql", + protocol: defaults.ProtocolMySQL, + opts: []ConnectCommandFunc{WithNoTLS()}, + cmd: []string{"mysql", "--user", "db-user", "--database", "db-name", "--port", "12345", "--host", "localhost", "--protocol", "TCP", "-e", "select 1"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + database := tlsca.RouteToDatabase{ + Protocol: tt.protocol, + Database: "db-name", + Username: "db-user", + ServiceName: "db-service", + } + + opts := append([]ConnectCommandFunc{ + WithLocalProxy("localhost", 12345, ""), + WithExecer(fakeExec), + }, tt.opts...) + + c := NewCmdBuilder(tc, profile, database, "root", opts...) + c.uid = utils.NewFakeUID() + got, err := c.GetExecCommand(context.Background(), "select 1") + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tt.cmd, got.Args) + }) + } +} diff --git a/lib/utils/fs.go b/lib/utils/fs.go index 0ddb0af13a1d9..6831a4a482b71 100644 --- a/lib/utils/fs.go +++ b/lib/utils/fs.go @@ -499,3 +499,13 @@ func RecursiveCopy(src, dest string, skip func(src, dest string) (bool, error)) return nil })) } + +// CreateExclusiveFile creates a file only if it does not exist to prevent overwriting +// existing files. +func CreateExclusiveFile(path string, mode os.FileMode) (*os.File, error) { + out, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, mode) + if err != nil { + return nil, trace.ConvertSystemError(err) + } + return out, nil +} diff --git a/lib/utils/log/slog.go b/lib/utils/log/slog.go index bfb34f4a94114..67ea8c250dcba 100644 --- a/lib/utils/log/slog.go +++ b/lib/utils/log/slog.go @@ -21,8 +21,10 @@ package log import ( "context" "fmt" + "iter" "log/slog" "reflect" + "slices" "strings" "unicode" @@ -201,3 +203,19 @@ func (a typeAttr) LogValue() slog.Value { } return slog.StringValue("nil") } + +type iterAttr[V any] struct { + iter iter.Seq[V] +} + +// IterAttr creates a [slog.LogValuer] that will defer the collection of an +// iter.Seq. +func IterAttr[V any](iter iter.Seq[V]) slog.LogValuer { + return iterAttr[V]{ + iter: iter, + } +} + +func (a iterAttr[V]) LogValue() slog.Value { + return slog.AnyValue(slices.Collect[V](a.iter)) +} diff --git a/lib/utils/unpack.go b/lib/utils/unpack.go index 35f76adbb452b..f11d2454e2f13 100644 --- a/lib/utils/unpack.go +++ b/lib/utils/unpack.go @@ -181,7 +181,7 @@ func writeFile(path string, r io.Reader, mode, dirMode os.FileMode) error { err := withDir(path, dirMode, func() error { // Create file only if it does not exist to prevent overwriting existing // files (like session recordings). - out, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, mode) + out, err := CreateExclusiveFile(path, mode) if err != nil { return trace.ConvertSystemError(err) } diff --git a/tool/common/common.go b/tool/common/common.go index a7d9edfe7565f..d0f9ef2ae9664 100644 --- a/tool/common/common.go +++ b/tool/common/common.go @@ -206,7 +206,7 @@ func FormatMultiValueLabels(labels map[string][]string, verbose bool) string { func FormatResourceName(r types.ResourceWithLabels, verbose bool) string { if !verbose { // return the (shorter) discovered name in non-verbose mode. - discoveredName, ok := r.GetAllLabels()[types.DiscoveredNameLabel] + discoveredName, ok := GetDiscoveredResourceName(r) if ok && discoveredName != "" { return discoveredName } @@ -214,6 +214,21 @@ func FormatResourceName(r types.ResourceWithLabels, verbose bool) string { return r.GetName() } +// GetDiscoveredResourceName returns the resource original name discovered in +// the cloud by the Teleport Discovery Service. +func GetDiscoveredResourceName(r types.ResourceWithLabels) (discoveredName string, ok bool) { + discoveredName, ok = r.GetAllLabels()[types.DiscoveredNameLabel] + return +} + +// SetDiscoveredResourceName sets the original name discovered in the cloud by +// the Teleport Discovery Service. +func SetDiscoveredResourceName(r types.ResourceWithLabels, discoveredName string) { + labels := r.GetStaticLabels() + labels[types.DiscoveredNameLabel] = discoveredName + r.SetStaticLabels(labels) +} + // FormatDefault formats a zero value with its default, or if the value is not // zero it just returns the value. func FormatDefault[T comparable](val, defaultVal T) string { diff --git a/tool/tsh/common/db.go b/tool/tsh/common/db.go index ff5fe4073653a..b0ca1770eea57 100644 --- a/tool/tsh/common/db.go +++ b/tool/tsh/common/db.go @@ -652,8 +652,11 @@ func maybeStartLocalProxy(ctx context.Context, cf *CLIConf, // certificate's DNS names. As such, connecting to 127.0.0.1 will fail // validation, so connect to localhost. host := "localhost" - cmdOpts := []dbcmd.ConnectCommandFunc{ + cmdOpts, err := makeDatabaseCommandOptions(ctx, tc, dbInfo, dbcmd.WithLocalProxy(host, addr.Port(0), profile.CACertPathForCluster(rootClusterName)), + ) + if err != nil { + return nil, trace.Wrap(err) } if requires.tunnel { cmdOpts = append(cmdOpts, dbcmd.WithNoTLS()) @@ -783,18 +786,6 @@ func onDatabaseConnect(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - opts = append(opts, - dbcmd.WithLogger(logger), - dbcmd.WithGetDatabaseFunc(dbInfo.getDatabaseForDBCmd), - ) - - if opts, err = maybeAddDBUserPassword(cf, tc, dbInfo, opts); err != nil { - return trace.Wrap(err) - } - if opts, err = maybeAddGCPMetadata(cf.Context, tc, dbInfo, opts); err != nil { - return trace.Wrap(err) - } - opts = maybeAddOracleOptions(cf.Context, tc, dbInfo, opts) bb := dbcmd.NewCmdBuilder(tc, profile, dbInfo.RouteToDatabase, rootClusterName, opts...) cmd, err := bb.GetConnectCommand(cf.Context) @@ -975,22 +966,7 @@ func (d *databaseInfo) checkAndSetDefaults(cf *CLIConf, tc *client.TeleportClien return nil } - var clusterClient *client.ClusterClient - err = client.RetryWithRelogin(cf.Context, tc, func() error { - clusterClient, err = tc.ConnectToCluster(cf.Context) - return trace.Wrap(err) - }) - if err != nil { - return trace.Wrap(err) - } - defer clusterClient.Close() - - profile, err := tc.ProfileStatus() - if err != nil { - return trace.Wrap(err) - } - - checker, err := services.NewAccessCheckerForRemoteCluster(cf.Context, profile.AccessInfo(), tc.SiteName, clusterClient.AuthClient) + checker, err := d.getChecker(cf.Context, tc) if err != nil { return trace.Wrap(err) } @@ -1014,6 +990,30 @@ func (d *databaseInfo) checkAndSetDefaults(cf *CLIConf, tc *client.TeleportClien return nil } +func (d *databaseInfo) getChecker(ctx context.Context, tc *client.TeleportClient) (services.AccessChecker, error) { + if d.checker != nil { + return d.checker, nil + } + var clusterClient *client.ClusterClient + var err error + err = client.RetryWithRelogin(ctx, tc, func() error { + clusterClient, err = tc.ConnectToCluster(ctx) + return trace.Wrap(err) + }) + if err != nil { + return nil, trace.Wrap(err) + } + defer clusterClient.Close() + + profile, err := tc.ProfileStatus() + if err != nil { + return nil, trace.Wrap(err) + } + + checker, err := services.NewAccessCheckerForRemoteCluster(ctx, profile.AccessInfo(), tc.SiteName, clusterClient.AuthClient) + return checker, trace.Wrap(err) +} + // databaseInfo wraps a RouteToDatabase and the corresponding database. // Its purpose is to prevent repeated fetches of the same database, by lazily // fetching and caching the database for use as needed. @@ -1024,6 +1024,7 @@ type databaseInfo struct { database types.Database // isActive indicates an active database matched this db info. isActive bool + checker services.AccessChecker mu sync.Mutex } diff --git a/tool/tsh/common/db_exec.go b/tool/tsh/common/db_exec.go new file mode 100644 index 0000000000000..f7fabf02ff6b4 --- /dev/null +++ b/tool/tsh/common/db_exec.go @@ -0,0 +1,637 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package common + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "slices" + "strings" + "sync" + + "github.com/gravitational/trace" + oteltrace "go.opentelemetry.io/otel/trace" + "golang.org/x/sync/errgroup" + + "github.com/gravitational/teleport" + apiclient "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/client/db/dbcmd" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/srv/alpnproxy" + "github.com/gravitational/teleport/lib/tlsca" + "github.com/gravitational/teleport/lib/utils" + logutils "github.com/gravitational/teleport/lib/utils/log" + "github.com/gravitational/teleport/tool/common" +) + +func onDatabaseExec(cf *CLIConf) error { + execCommand, err := newDatabaseExecCommand(cf) + if err != nil { + return trace.Wrap(err) + } + defer execCommand.close() + + return trace.Wrap(execCommand.run()) +} + +// databaseExecClient is a wrapper of client.TeleportClient that makes backend +// calls. can be mocked for testing. +type databaseExecClient interface { + close() error + getProfileStatus() *client.ProfileStatus + getAccessChecker() services.AccessChecker + issueCert(context.Context, *databaseInfo) (tls.Certificate, error) + listDatabasesWithFilter(context.Context, *proto.ListResourcesRequest) ([]types.Database, error) +} + +type databaseExecCommand struct { + cf *CLIConf + tc *client.TeleportClient + client databaseExecClient + makeCommand func(context.Context, *databaseInfo, string, string) (*exec.Cmd, error) + summary databaseExecSummary +} + +func newDatabaseExecCommand(cf *CLIConf) (*databaseExecCommand, error) { + if err := checkDatabaseExecInputFlags(cf); err != nil { + return nil, trace.Wrap(err) + } + + tc, err := makeClient(cf) + if err != nil { + return nil, trace.Wrap(err) + } + + sharedClient, err := newSharedDatabaseExecClient(cf, tc) + if err != nil { + return nil, trace.Wrap(err) + } + + commandMaker, err := newDatabaseExecCommandMaker(cf.Context, tc, sharedClient.getProfileStatus()) + if err != nil { + return nil, trace.Wrap(err) + } + + return &databaseExecCommand{ + cf: cf, + tc: tc, + client: sharedClient, + makeCommand: commandMaker.makeCommand, + }, nil +} + +func (c *databaseExecCommand) run() error { + // Fetch. + dbs, err := c.getDatabases() + if err != nil { + return trace.Wrap(err) + } + + // Execute in parallel. + group, groupCtx := errgroup.WithContext(c.cf.Context) + group.SetLimit(c.cf.ParallelJobs) + for _, db := range dbs { + group.Go(func() error { + result := c.exec(groupCtx, db) + c.summary.add(result) + return nil + }) + } + + // Print summary. + defer func() { + switch { + case c.cf.OutputDir != "": + c.summary.printAndSave(c.cf.Stdout(), c.cf.OutputDir) + case len(dbs) > 1: + c.summary.print(c.cf.Stdout()) + } + }() + + return trace.Wrap(group.Wait()) +} + +func (c *databaseExecCommand) close() { + if err := c.client.close(); err != nil { + logger.WarnContext(c.cf.Context, "Failed to close client", "error", err) + } +} + +func checkDatabaseExecInputFlags(cf *CLIConf) error { + // Pick an arbitrary number for max connections to avoid flooding the + // backend. The limit can be overwritten with the "TELEPORT_PARALLEL_JOBS" + // env var. + const maxParallelJobs = 10 + if cf.ParallelJobs < 1 || cf.ParallelJobs > maxParallelJobs { + return trace.BadParameter(`--parallel must be between 1 and %v`, maxParallelJobs) + } + + // Selection flags. + byNames := cf.DatabaseServices != "" + bySearch := cf.SearchKeywords != "" || cf.Labels != "" + switch { + case !byNames && !bySearch: + return trace.BadParameter("please provide one of --dbs, --labels, --search flags") + case byNames && bySearch: + return trace.BadParameter("--labels/--search flags cannot be used with --dbs flag") + } + + // Logging. + if cf.ParallelJobs > 1 && cf.OutputDir == "" { + return trace.BadParameter("--output-dir must be set when executing concurrent connections") + } + if cf.OutputDir != "" && utils.FileExists(cf.OutputDir) { + return trace.BadParameter("directory %q already exists", cf.OutputDir) + } + return nil +} + +func (c *databaseExecCommand) getDatabases() ([]types.Database, error) { + if c.cf.DatabaseServices != "" { + return c.getDatabasesByNames() + } + return c.searchDatabases() +} + +func (c *databaseExecCommand) getDatabasesByNames() ([]types.Database, error) { + // Use a single predicate to search multiple names in one shot but batch 100 + // names at a time. Extra validation will be performed afterward to ensure + // we fetched what we need. + fmt.Fprintln(c.cf.Stdout(), "Fetching databases by name ...") + + // Trim spaces. + names := stringFlagToStrings(c.cf.DatabaseServices) + + var dbs []types.Database + for page := range slices.Chunk(names, 100) { + var predicate string + for _, name := range page { + predicate = makePredicateDisjunction(predicate, makeDiscoveredNameOrNamePredicate(name)) + } + + logger.DebugContext(c.cf.Context, "Getting database by name", "databases", page) + pageDBs, err := c.client.listDatabasesWithFilter(c.cf.Context, &proto.ListResourcesRequest{ + Namespace: apidefaults.Namespace, + ResourceType: types.KindDatabaseServer, + PredicateExpression: predicate, + }) + if err != nil { + return nil, trace.Wrap(err) + } + dbs = append(dbs, pageDBs...) + } + + logger.DebugContext(c.cf.Context, "Fetched databases by name", + "databases", logutils.IterAttr(types.ResourceNames(dbs))) + return dbs, trace.Wrap(ensureEachDatabase(names, dbs)) +} + +func (c *databaseExecCommand) searchDatabases() (databases []types.Database, err error) { + fmt.Fprintln(c.cf.Stdout(), "Searching databases ...") + filter := c.tc.ResourceFilter(types.KindDatabaseServer) + + logger.DebugContext(c.cf.Context, "Searching for databases", "filter", filter) + dbs, err := c.client.listDatabasesWithFilter(c.cf.Context, filter) + if err != nil { + return nil, trace.Wrap(err) + } + + logger.DebugContext(c.cf.Context, "Fetched databases with search filter", + "databases", logutils.IterAttr(types.ResourceNames(dbs)), + ) + return dbs, trace.Wrap(c.printSearchResultAndConfirm(dbs)) +} + +func (c *databaseExecCommand) printSearchResultAndConfirm(dbs []types.Database) error { + if len(dbs) == 0 { + return trace.NotFound("no databases found") + } + + fmt.Fprintf(c.cf.Stdout(), "Found %d database(s):\n\n", len(dbs)) + printTableForDatabaseExec(c.cf.Stdout(), dbs) + question := fmt.Sprintf("Do you want to proceed with %d database(s)?", len(dbs)) + if err := c.cf.PromptConfirmation(question); err != nil { + return trace.Wrap(err) + } + return nil +} + +func (c *databaseExecCommand) exec(ctx context.Context, db types.Database) (result databaseExecResult) { + result = databaseExecResult{ + RouteToDatabase: client.RouteToDatabaseToProto(c.makeRouteToDatabase(db)), + Command: c.cf.DatabaseCommand, + Success: true, + } + + printErrorAndMakeErrorResult := func(err error) databaseExecResult { + fmt.Fprintf(c.cf.Stderr(), "Failed to execute command for %q: %v\n", db.GetName(), err) + result.Success = false + result.Error = err.Error() + return result + } + + if ctx.Err() != nil { + return printErrorAndMakeErrorResult(ctx.Err()) + } + + outputWriter := c.cf.Stdout() + errWriter := c.cf.Stderr() + switch { + case c.cf.OutputDir != "": + // Use full-name instead of display name for output path. + logFile, err := c.openOutputFile(db.GetName()) + if err != nil { + return printErrorAndMakeErrorResult(err) + } + defer logFile.Close() + outputWriter = logFile + errWriter = logFile + fmt.Fprintf(c.cf.Stdout(), "Executing command for %q. Output will be saved at %q.\n", db.GetName(), logFile.Name()) + + // Save absolute path in the summary. Not expecting the absolute check + // to fail but use the filename in case it does. + if result.OutputFile, err = filepath.Abs(logFile.Name()); err != nil { + result.OutputFile = filepath.Base(logFile.Name()) + } + default: + // No prefix so output can still be copy-pasted. Extra empty line to + // separate sequential executions. + fmt.Fprintf(c.cf.Stdout(), "\nExecuting command for %q.\n", db.GetName()) + } + + var err error + result.ExitCode, err = c.runCommand(ctx, db, outputWriter, errWriter) + if err != nil { + return printErrorAndMakeErrorResult(err) + } + return result +} + +func (c *databaseExecCommand) runCommand(ctx context.Context, db types.Database, outputWriter, errWriter io.Writer) (int, error) { + dbInfo, err := c.makeDatabaseInfo(db) + if err != nil { + return 0, trace.Wrap(err) + } + lp, err := c.startLocalProxy(ctx, dbInfo) + if err != nil { + return 0, trace.Wrap(err) + } + defer lp.Close() + + dbCmd, err := c.makeCommand(ctx, dbInfo, lp.GetAddr(), c.cf.DatabaseCommand) + if err != nil { + return 0, trace.Wrap(err) + } + dbCmd.Stdout = outputWriter + dbCmd.Stderr = errWriter + + logger.DebugContext(ctx, "Executing database command", "command", dbCmd, "db", dbInfo.ServiceName) + runErr := c.cf.RunCommand(dbCmd) + if dbCmd.ProcessState != nil { + return dbCmd.ProcessState.ExitCode(), trace.Wrap(runErr) + } + return 0, trace.Wrap(runErr) +} + +func (c *databaseExecCommand) startLocalProxy(ctx context.Context, dbInfo *databaseInfo) (*alpnproxy.LocalProxy, error) { + // Issue single-use certificate. + clientCert, err := c.client.issueCert(ctx, dbInfo) + if err != nil { + return nil, trace.Wrap(err) + } + + // Do not provide a re-issuer middleware to the local proxy. The local proxy + // is meant for one-time use so there is no need to re-issue the + // certificates. + listener, err := createLocalProxyListener("localhost:0", dbInfo.RouteToDatabase, c.client.getProfileStatus()) + if err != nil { + return nil, trace.Wrap(err) + } + + opts := []alpnproxy.LocalProxyConfigOpt{ + alpnproxy.WithDatabaseProtocol(dbInfo.Protocol), + alpnproxy.WithClusterCAsIfConnUpgrade(ctx, c.tc.RootClusterCACertPool), + alpnproxy.WithClientCert(clientCert), + } + + lpConfig := makeBasicLocalProxyConfig(ctx, c.tc, listener, c.cf.InsecureSkipVerify) + lp, err := alpnproxy.NewLocalProxy(lpConfig, opts...) + if err != nil { + return nil, trace.Wrap(err) + } + + go func() { + defer listener.Close() + if err := lp.Start(ctx); err != nil { + logger.ErrorContext(ctx, "Failed to start local proxy", "error", err) + } + }() + return lp, nil +} + +func (c *databaseExecCommand) makeRouteToDatabase(db types.Database) tlsca.RouteToDatabase { + return tlsca.RouteToDatabase{ + ServiceName: db.GetName(), + Protocol: db.GetProtocol(), + Username: c.cf.DatabaseUser, + Database: c.cf.DatabaseName, + Roles: requestedDatabaseRoles(c.cf), + } +} + +func (c *databaseExecCommand) makeDatabaseInfo(db types.Database) (*databaseInfo, error) { + dbInfo := &databaseInfo{ + RouteToDatabase: c.makeRouteToDatabase(db), + database: db, + checker: c.client.getAccessChecker(), + } + return dbInfo, trace.Wrap(dbInfo.checkAndSetDefaults(c.cf, c.tc)) +} + +func (c *databaseExecCommand) outputFilename(dbServiceName string) string { + return filepath.Join(c.cf.OutputDir, dbServiceName+".output") +} + +func (c *databaseExecCommand) openOutputFile(dbServiceName string) (*os.File, error) { + logFilePath, err := utils.EnsureLocalPath(c.outputFilename(dbServiceName), "", "") + if err != nil { + return nil, trace.Wrap(err) + } + logFile, err := utils.CreateExclusiveFile(logFilePath, teleport.FileMaskOwnerOnly) + return logFile, trace.ConvertSystemError(err) +} + +// sharedDatabaseExecClient is a wrapper of client.TeleportClient that makes +// backend calls while using a shared ClusterClient. +type sharedDatabaseExecClient struct { + profile *client.ProfileStatus + clusterClient *client.ClusterClient + accessChecker services.AccessChecker + tracer oteltrace.Tracer +} + +func newSharedDatabaseExecClient(cf *CLIConf, tc *client.TeleportClient) (*sharedDatabaseExecClient, error) { + var clusterClient *client.ClusterClient + var err error + if err := client.RetryWithRelogin(cf.Context, tc, func() error { + clusterClient, err = tc.ConnectToCluster(cf.Context) + return trace.Wrap(err) + }); err != nil { + return nil, trace.Wrap(err) + } + + profile, err := tc.ProfileStatus() + if err != nil { + return nil, trace.Wrap(err) + } + + accessChecker, err := services.NewAccessCheckerForRemoteCluster(cf.Context, profile.AccessInfo(), tc.SiteName, clusterClient.AuthClient) + if err != nil { + return nil, trace.Wrap(err) + } + + return &sharedDatabaseExecClient{ + profile: profile, + tracer: cf.TracingProvider.Tracer(teleport.ComponentTSH), + clusterClient: clusterClient, + accessChecker: accessChecker, + }, nil +} + +func (c *sharedDatabaseExecClient) close() error { + if err := c.clusterClient.Close(); err != nil && !trace.IsConnectionProblem(err) { + return trace.Wrap(err) + } + return nil +} + +func (c *sharedDatabaseExecClient) getAccessChecker() services.AccessChecker { + return c.accessChecker +} + +func (c *sharedDatabaseExecClient) getProfileStatus() *client.ProfileStatus { + return c.profile +} + +// issueCert issues a single use cert for the db route. +func (c *sharedDatabaseExecClient) issueCert(ctx context.Context, dbInfo *databaseInfo) (tls.Certificate, error) { + // TODO(greedy52) add support for multi-session MFA. + params := client.ReissueParams{ + RouteToDatabase: client.RouteToDatabaseToProto(dbInfo.RouteToDatabase), + AccessRequests: c.profile.ActiveRequests, + } + + keyRing, _, err := c.clusterClient.IssueUserCertsWithMFA(ctx, params) + if err != nil { + return tls.Certificate{}, trace.Wrap(err) + } + dbCert, err := keyRing.DBTLSCert(dbInfo.RouteToDatabase.ServiceName) + return dbCert, trace.Wrap(err) +} + +func (c *sharedDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, filter *proto.ListResourcesRequest) (databases []types.Database, err error) { + ctx, span := c.tracer.Start( + ctx, + "listDatabasesWithFilter", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + + servers, err := apiclient.GetAllResources[types.DatabaseServer](ctx, c.clusterClient.AuthClient, filter) + if err != nil { + return nil, trace.Wrap(err) + } + return types.DatabaseServers(servers).ToDatabases(), nil +} + +type databaseExecCommandMaker struct { + tc *client.TeleportClient + profile *client.ProfileStatus + rootCluster string +} + +func newDatabaseExecCommandMaker(ctx context.Context, tc *client.TeleportClient, profile *client.ProfileStatus) (*databaseExecCommandMaker, error) { + rootCluster, err := tc.RootClusterName(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + return &databaseExecCommandMaker{ + tc: tc, + profile: profile, + rootCluster: rootCluster, + }, nil +} + +func (m *databaseExecCommandMaker) makeCommand(ctx context.Context, dbInfo *databaseInfo, lpAddr, command string) (*exec.Cmd, error) { + addr, err := utils.ParseAddr(lpAddr) + if err != nil { + return nil, trace.Wrap(err) + } + opts, err := makeDatabaseCommandOptions(ctx, m.tc, dbInfo, + dbcmd.WithLocalProxy("localhost", addr.Port(0), ""), + dbcmd.WithNoTLS(), + ) + if err != nil { + return nil, trace.Wrap(err) + } + return dbcmd.NewCmdBuilder(m.tc, m.profile, dbInfo.RouteToDatabase, m.rootCluster, opts...). + GetExecCommand(ctx, command) +} + +// ensureEachDatabase ensures one to one mapping between the provided database +// target names and database resources. +// +// Note that it is assumed that the provided database resource has at least one +// matching names as they are retrieved from the backend based on one of those +// names. +func ensureEachDatabase(names []string, dbs []types.Database) error { + byDiscoveredNameOrName := map[string]types.Databases{} + for _, db := range dbs { + // Database may be listed by their original name in the cloud. + byDiscoveredNameOrName[db.GetName()] = append(byDiscoveredNameOrName[db.GetName()], db) + + if discoveredName, ok := common.GetDiscoveredResourceName(db); ok && discoveredName != db.GetName() { + byDiscoveredNameOrName[discoveredName] = append(byDiscoveredNameOrName[discoveredName], db) + } + } + + for _, name := range names { + matched := byDiscoveredNameOrName[name] + switch len(matched) { + case 0: + return trace.NotFound("database %q not found", name) + case 1: + continue + default: + var sb strings.Builder + printTableForDatabaseExec(&sb, matched) + return trace.BadParameter(`%q matches multiple databases: +%vTry selecting the database with a more specific name printed in the above table`, name, sb.String()) + } + } + + return nil +} + +func printTableForDatabaseExec(w io.Writer, dbs []types.Database) { + rows := make([]databaseTableRow, 0, len(dbs)) + for _, db := range dbs { + // Always use full name but don't print hidden labels. + row := getDatabaseRow("", "", "", db, nil, nil, false) + row.DisplayName = db.GetName() + rows = append(rows, row) + } + printDatabaseTable(printDatabaseTableConfig{ + writer: w, + rows: rows, + includeColumns: []string{"Name", "Protocol", "Description", "Labels"}, + }) +} + +type databaseExecResult struct { + proto.RouteToDatabase `json:"database"` + Command string `json:"command"` + OutputFile string `json:"output_file,omitempty"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` + ExitCode int `json:"exit_code"` +} + +type databaseExecSummary struct { + Databases []databaseExecResult `json:"databases"` + Success int `json:"success"` + Failure int `json:"failure"` + Total int `json:"total"` + + mu sync.Mutex +} + +func (s *databaseExecSummary) add(result databaseExecResult) { + s.mu.Lock() + defer s.mu.Unlock() + s.Databases = append(s.Databases, result) + s.Total++ + if result.Success { + s.Success++ + } else { + s.Failure++ + } +} + +func (s *databaseExecSummary) print(w io.Writer) { + s.mu.Lock() + defer s.mu.Unlock() + fmt.Fprintf(w, "\nSummary: %d of %d succeeded.\n", s.Success, s.Total) +} + +func (s *databaseExecSummary) printAndSave(w io.Writer, outputDir string) { + s.print(w) + if err := s.save(w, outputDir); err != nil { + fmt.Fprintf(w, "Failed to save summary: %v\n", err) + } +} + +func (s *databaseExecSummary) save(w io.Writer, outputDir string) error { + summaryPath := filepath.Join(outputDir, "summary.json") + summaryPath, err := utils.EnsureLocalPath(summaryPath, "", "") + if err != nil { + return trace.Wrap(err) + } + summaryFile, err := utils.CreateExclusiveFile(summaryPath, teleport.FileMaskOwnerOnly) + if trace.IsAlreadyExists(err) { + fmt.Fprintf(w, "Warning: file %s exists and will be overwritten.\n", summaryPath) + summaryFile, err = os.OpenFile(summaryPath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, teleport.FileMaskOwnerOnly) + if err != nil { + return trace.ConvertSystemError(err) + } + } else if err != nil { + return trace.Wrap(err) + } + defer summaryFile.Close() + + summaryData, err := s.makeSummaryJSONData() + if err != nil { + return trace.Wrap(err) + } + + if _, err := summaryFile.Write(summaryData); err != nil { + return trace.ConvertSystemError(err) + } + + fmt.Fprintf(w, "Summary is saved at %q.\n", summaryPath) + return nil +} + +func (s *databaseExecSummary) makeSummaryJSONData() ([]byte, error) { + s.mu.Lock() + defer s.mu.Unlock() + + summaryData, err := json.MarshalIndent(s, "", " ") + return summaryData, trace.Wrap(err) +} diff --git a/tool/tsh/common/db_exec_test.go b/tool/tsh/common/db_exec_test.go new file mode 100644 index 0000000000000..dc99644fd1b04 --- /dev/null +++ b/tool/tsh/common/db_exec_test.go @@ -0,0 +1,525 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package common + +import ( + "bytes" + "context" + "crypto/tls" + "fmt" + "os" + "os/exec" + "path/filepath" + "slices" + "strings" + "testing" + + "github.com/gravitational/trace" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/tool/common" +) + +func Test_checkDatabaseExecInputFlags(t *testing.T) { + dir := t.TempDir() + + tests := []struct { + name string + cf *CLIConf + checkError require.ErrorAssertionFunc + }{ + { + name: "with database services", + cf: &CLIConf{ + ParallelJobs: 1, + DatabaseServices: "db1,db2", + }, + checkError: require.NoError, + }, + { + name: "with search", + cf: &CLIConf{ + ParallelJobs: 1, + SearchKeywords: "dev", + }, + checkError: require.NoError, + }, + { + name: "with labels", + cf: &CLIConf{ + ParallelJobs: 1, + Labels: "env=dev", + }, + checkError: require.NoError, + }, + { + name: "invalid max connections", + cf: &CLIConf{ + ParallelJobs: 15, + Labels: "env=dev", + }, + checkError: require.Error, + }, + { + name: "missing selection", + cf: &CLIConf{ + ParallelJobs: 1, + }, + checkError: require.Error, + }, + { + name: "too many selection options", + cf: &CLIConf{ + ParallelJobs: 1, + Labels: "env=dev", + DatabaseServices: "db1,db2", + }, + checkError: require.Error, + }, + { + name: "missing output dir", + cf: &CLIConf{ + ParallelJobs: 5, + Labels: "env=dev", + }, + checkError: require.Error, + }, + { + name: "output dir exists", + cf: &CLIConf{ + ParallelJobs: 5, + OutputDir: dir, + Labels: "env=dev", + }, + checkError: require.Error, + }, + { + name: "max connections and output dir", + cf: &CLIConf{ + ParallelJobs: 5, + OutputDir: filepath.Join(dir, "output"), + Labels: "env=dev", + }, + checkError: require.NoError, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.checkError(t, checkDatabaseExecInputFlags(tt.cf)) + }) + } +} + +func TestDatabaseExec(t *testing.T) { + // Populate fake client. + cert, err := tls.X509KeyPair([]byte(fixtures.TLSCACertPEM), []byte(fixtures.TLSCAKeyPEM)) + require.NoError(t, err) + + fakeClient := &fakeDatabaseExecClient{ + cert: cert, + allDatabaseServers: []types.DatabaseServer{ + mustMakeDatabaseServerForEnv(t, "pg1", types.DatabaseProtocolPostgreSQL, "dev"), + mustMakeDatabaseServerForEnv(t, "pg2", types.DatabaseProtocolPostgreSQL, "dev"), + mustMakeDatabaseServerForEnv(t, "pg3", types.DatabaseProtocolPostgreSQL, "prod"), + mustMakeDatabaseServerForEnv(t, "mysql", types.DatabaseProtocolMySQL, "prod"), + mustMakeDatabaseServerForEnv(t, "mongo", types.DatabaseProtocolMongoDB, "staging"), + }, + } + + // Commands are not actually being run but passed to cf.RunCommand. + // Here just passing the query through the command for verification. + dbQuery := "db-query" + makeCommand := func(_ context.Context, dbInfo *databaseInfo, _ string, dbQuery string) (*exec.Cmd, error) { + return exec.Command(dbQuery), nil + } + verifyDBQuery := func(cmd *exec.Cmd) error { + if !slices.Equal(cmd.Args, []string{dbQuery}) { + return trace.CompareFailed("expect %q but got %q", dbQuery, cmd.Args) + } + fmt.Fprintln(cmd.Stdout, dbQuery, "executed") + return nil + } + + tests := []struct { + name string + setup func(*testing.T, *databaseExecCommand) + wantError string + expectOutputContains []string + verifyDir func(t *testing.T, dir string) + }{ + { + name: "no databases found by search", + setup: func(_ *testing.T, cmd *databaseExecCommand) { + cmd.cf.SearchKeywords = "not-found" + }, + wantError: "no databases found", + }, + { + name: "no databases found by names", + setup: func(_ *testing.T, cmd *databaseExecCommand) { + cmd.cf.DatabaseServices = "not-found" + }, + wantError: "not found", + }, + { + name: "by names", + setup: func(_ *testing.T, cmd *databaseExecCommand) { + cmd.cf.DatabaseServices = "pg1,pg2,pg3" + }, + expectOutputContains: []string{ + "Fetching databases by name", + "Executing command for \"pg1\".", + "db-query executed", + "Executing command for \"pg2\".", + "Executing command for \"pg3\".", + "Summary:", + }, + }, + { + name: "by keyword", + setup: func(_ *testing.T, cmd *databaseExecCommand) { + cmd.cf.SearchKeywords = "mysql" + }, + expectOutputContains: []string{ + "Found 1 database(s)", + "Name Description Protocol Labels", + "----- ----------- -------- --------", + "mysql mysql env=prod", + "Executing command for \"mysql\".", + "db-query executed", + }, + }, + { + name: "by env", + setup: func(_ *testing.T, cmd *databaseExecCommand) { + cmd.cf.Labels = "env=dev" + }, + expectOutputContains: []string{ + "Found 2 database(s)", + "Name Description Protocol Labels", + "---- ----------- -------- -------", + "pg1 postgres env=dev", + "pg2 postgres env=dev", + "Executing command for \"pg1\".", + "db-query executed", + "Executing command for \"pg2\".", + "Summary:", + }, + }, + { + name: "output dir", + setup: func(_ *testing.T, cmd *databaseExecCommand) { + cmd.cf.DatabaseServices = "pg3,mysql" + cmd.cf.OutputDir = filepath.Join(cmd.cf.HomePath, "test-output") + }, + expectOutputContains: []string{ + "Fetching databases by name", + "Executing command for \"pg3\". Output will be saved at", + "Executing command for \"mysql\". Output will be saved at", + "Summary:", + "Summary is saved", + }, + verifyDir: func(t *testing.T, dir string) { + t.Helper() + read, err := utils.ReadPath(filepath.Join(dir, "test-output", "pg3.output")) + require.NoError(t, err) + require.Equal(t, "db-query executed", strings.TrimSpace(string(read))) + read, err = utils.ReadPath(filepath.Join(dir, "test-output", "mysql.output")) + require.NoError(t, err) + require.Equal(t, "db-query executed", strings.TrimSpace(string(read))) + require.True(t, utils.FileExists(filepath.Join(dir, "test-output", "summary.json"))) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir := t.TempDir() + + // Prep CLIConf. + var capture bytes.Buffer + writer := utils.NewSyncWriter(&capture) + cf := &CLIConf{ + Proxy: "proxy:3080", + Context: context.Background(), + HomePath: dir, + ParallelJobs: 1, + DatabaseUser: "db-user", + DatabaseName: "db-name", + DatabaseCommand: dbQuery, + cmdRunner: verifyDBQuery, + OverrideStdout: writer, + overrideStderr: writer, + Confirm: false, + } + + // Prep command and sanity check. + c := &databaseExecCommand{ + cf: cf, + client: fakeClient, + makeCommand: makeCommand, + } + tt.setup(t, c) + mustCreateEmptyProfile(t, cf) + c.tc, err = makeClient(cf) + require.NoError(t, err) + require.NoError(t, checkDatabaseExecInputFlags(c.cf)) + + runError := c.run() + if tt.wantError != "" { + require.Error(t, runError) + require.Contains(t, runError.Error(), tt.wantError) + return + } + + output := capture.String() + for _, expect := range tt.expectOutputContains { + require.Contains(t, output, expect) + } + + if tt.verifyDir != nil { + tt.verifyDir(t, dir) + } + }) + } + +} + +type fakeDatabaseExecClient struct { + cert tls.Certificate + allDatabaseServers []types.DatabaseServer +} + +func (c *fakeDatabaseExecClient) close() error { + return nil +} +func (c *fakeDatabaseExecClient) getProfileStatus() *client.ProfileStatus { + return &client.ProfileStatus{} +} +func (c *fakeDatabaseExecClient) getAccessChecker() services.AccessChecker { + return services.NewAccessCheckerWithRoleSet(&services.AccessInfo{}, "clustername", services.NewRoleSet()) +} +func (c *fakeDatabaseExecClient) issueCert(context.Context, *databaseInfo) (tls.Certificate, error) { + return c.cert, nil +} +func (c *fakeDatabaseExecClient) listDatabasesWithFilter(ctx context.Context, req *proto.ListResourcesRequest) ([]types.Database, error) { + filter := services.MatchResourceFilter{ + ResourceKind: req.ResourceType, + Labels: req.Labels, + SearchKeywords: req.SearchKeywords, + } + if req.PredicateExpression != "" { + expression, err := services.NewResourceExpression(req.PredicateExpression) + if err != nil { + return nil, trace.Wrap(err) + } + filter.PredicateExpression = expression + } + + var filtered []types.Database + for _, dbServer := range c.allDatabaseServers { + match, err := services.MatchResourceByFilters(dbServer, filter, nil) + if err != nil { + return nil, trace.Wrap(err) + } else if match { + filtered = append(filtered, dbServer.GetDatabase()) + } + } + return filtered, nil +} + +func mustMakeDatabaseServer(t *testing.T, db types.Database) types.DatabaseServer { + t.Helper() + + dbV3, ok := db.(*types.DatabaseV3) + require.True(t, ok) + + server, err := types.NewDatabaseServerV3(types.Metadata{ + Name: db.GetName(), + }, types.DatabaseServerSpecV3{ + Version: teleport.Version, + Hostname: "hostname", + HostID: "host-id", + Database: dbV3, + ProxyIDs: []string{"proxy"}, + }) + require.NoError(t, err) + return server +} + +func mustMakeDatabaseForEnv(t *testing.T, name, protocol, env string) types.Database { + t.Helper() + db, err := types.NewDatabaseV3( + types.Metadata{ + Name: name, + Labels: map[string]string{"env": env}, + }, + types.DatabaseSpecV3{ + Protocol: protocol, + URI: "localhost:12345", + }, + ) + require.NoError(t, err) + return db +} + +func mustMakeDatabaseServerForEnv(t *testing.T, name, protocol, env string) types.DatabaseServer { + t.Helper() + db := mustMakeDatabaseForEnv(t, name, protocol, env) + return mustMakeDatabaseServer(t, db) +} + +func Test_ensureEachDatabase(t *testing.T) { + devDB := mustMakeDatabaseForEnv(t, "dev", "postgres", "dev") + stagingDB := mustMakeDatabaseForEnv(t, "staging", "postgres", "staging") + prodDB1 := mustMakeDatabaseForEnv(t, "prod", "postgres", "prod") + prodDB2 := mustMakeDatabaseForEnv(t, "prod", "postgres", "prod") + common.SetDiscoveredResourceName(stagingDB, "staging") // edge case where discovered name is the same. + common.SetDiscoveredResourceName(prodDB1, "prod-cloud1") + common.SetDiscoveredResourceName(prodDB2, "prod-cloud2") + + tests := []struct { + name string + inputNames []string + inputDatabases []types.Database + expectErrorContains string + }{ + { + name: "exact match", + inputNames: []string{"dev", "staging"}, + inputDatabases: []types.Database{devDB, stagingDB}, + }, + { + name: "discovered name match", + inputNames: []string{"prod-cloud1", "prod-cloud2", "dev"}, + inputDatabases: []types.Database{devDB, prodDB1, prodDB2}, + }, + { + name: "database not found", + inputNames: []string{"dev", "staging", "prod-cloud5"}, + inputDatabases: []types.Database{devDB, stagingDB}, + expectErrorContains: "\"prod-cloud5\" not found", + }, + { + name: "ambiguous name", + inputNames: []string{"prod"}, + inputDatabases: []types.Database{prodDB1, prodDB2}, + expectErrorContains: "\"prod\" matches multiple databases", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ensureEachDatabase(tt.inputNames, tt.inputDatabases) + if tt.expectErrorContains == "" { + require.NoError(t, err) + } else { + require.Error(t, err) + require.Contains(t, err.Error(), tt.expectErrorContains) + } + }) + } +} + +func Test_databaseExecSummary(t *testing.T) { + summary := databaseExecSummary{} + summary.add(databaseExecResult{ + RouteToDatabase: proto.RouteToDatabase{ + ServiceName: "db1", + Protocol: "postgres", + Username: "db-user", + }, + Success: true, + }) + summary.add(databaseExecResult{ + RouteToDatabase: proto.RouteToDatabase{ + ServiceName: "db2", + Protocol: "postgres", + Username: "db-user", + }, + Error: "some error", + ExitCode: 1, + }) + summary.add(databaseExecResult{ + RouteToDatabase: proto.RouteToDatabase{ + ServiceName: "db3", + Protocol: "postgres", + Username: "db-user", + }, + Success: true, + }) + + var buf bytes.Buffer + summary.print(&buf) + require.Contains(t, buf.String(), "Summary: 2 of 3 succeeded") + + buf.Reset() + dir := t.TempDir() + expectPath := filepath.Join(dir, "summary.json") + summary.printAndSave(&buf, dir) + require.Contains(t, buf.String(), "Summary: 2 of 3 succeeded") + require.Contains(t, buf.String(), fmt.Sprintf("Summary is saved at %q", expectPath)) + summaryData, err := os.ReadFile(expectPath) + require.NoError(t, err) + require.Equal(t, `{ + "databases": [ + { + "database": { + "service_name": "db1", + "protocol": "postgres", + "username": "db-user" + }, + "command": "", + "success": true, + "exit_code": 0 + }, + { + "database": { + "service_name": "db2", + "protocol": "postgres", + "username": "db-user" + }, + "command": "", + "success": false, + "error": "some error", + "exit_code": 1 + }, + { + "database": { + "service_name": "db3", + "protocol": "postgres", + "username": "db-user" + }, + "command": "", + "success": true, + "exit_code": 0 + } + ], + "success": 2, + "failure": 1, + "total": 3 +}`, string(summaryData)) +} diff --git a/tool/tsh/common/db_print.go b/tool/tsh/common/db_print.go index c79a44d5dec64..b86e856584fba 100644 --- a/tool/tsh/common/db_print.go +++ b/tool/tsh/common/db_print.go @@ -79,9 +79,20 @@ type printDatabaseTableConfig struct { rows []databaseTableRow showProxyAndCluster bool verbose bool + // includeColumns specifies a whitelist of columns to include. verbose and + // showProxyAndCluster are ignored when includeColumns is provided. + includeColumns []string } -func (cfg printDatabaseTableConfig) excludeColumns() (out []string) { +func (cfg printDatabaseTableConfig) excludeColumns(allColumns []string) (out []string) { + if len(cfg.includeColumns) > 0 { + for _, column := range allColumns { + if !slices.Contains(cfg.includeColumns, column) { + out = append(out, column) + } + } + return + } if !cfg.showProxyAndCluster { out = append(out, "Proxy", "Cluster") } @@ -94,7 +105,7 @@ func (cfg printDatabaseTableConfig) excludeColumns() (out []string) { func printDatabaseTable(cfg printDatabaseTableConfig) { allColumns := makeTableColumnTitles(databaseTableRow{}) rowsWithAllColumns := makeTableRows(cfg.rows) - excludeColumns := cfg.excludeColumns() + excludeColumns := cfg.excludeColumns(allColumns) var printColumns []string printRows := make([][]string, len(cfg.rows)) diff --git a/tool/tsh/common/db_print_test.go b/tool/tsh/common/db_print_test.go index 23593cb6f2b48..22ca4b6a9a8f8 100644 --- a/tool/tsh/common/db_print_test.go +++ b/tool/tsh/common/db_print_test.go @@ -106,6 +106,19 @@ db2 describe db2 mysql self-hosted localhost:3306 [alice] [readonly] proxy cluster1 db1 describe db1 postgres self-hosted localhost:5432 [*] Env=dev tsh db connect db1 proxy cluster1 db2 describe db2 mysql self-hosted localhost:3306 [alice] [readonly] Env=prod +`, + }, + { + name: "tsh db exec search results", + cfg: printDatabaseTableConfig{ + rows: rows, + includeColumns: []string{"Name", "Protocol", "Description", "Labels"}, + }, + expect: `Name Description Protocol Labels +---- ------------ -------- -------- +db1 describe db1 postgres Env=dev +db2 describe db2 mysql Env=prod + `, }, } diff --git a/tool/tsh/common/git_list_test.go b/tool/tsh/common/git_list_test.go index cf4d52e318043..15a40a625e94f 100644 --- a/tool/tsh/common/git_list_test.go +++ b/tool/tsh/common/git_list_test.go @@ -128,21 +128,14 @@ func TestGitListCommand(t *testing.T) { } // Create a empty profile so we don't ping proxy. - clientStore, err := initClientStore(cf, cf.Proxy) - require.NoError(t, err) - profile := &profile.Profile{ - SSHProxyAddr: "proxy:3023", - WebProxyAddr: "proxy:3080", - } - err = clientStore.SaveProfile(profile, true) - require.NoError(t, err) + mustCreateEmptyProfile(t, cf) cmd := gitListCommand{ format: test.format, fetchFn: test.fetchFn, } - err = cmd.run(cf) + err := cmd.run(cf) if test.wantError { require.Error(t, err) } else { @@ -154,3 +147,21 @@ func TestGitListCommand(t *testing.T) { }) } } + +// mustCreateEmptyProfile creates an empty profile so we don't ping proxy when +// calling makeClient on provided cf. +func mustCreateEmptyProfile(t *testing.T, cf *CLIConf) { + t.Helper() + + if cf.HomePath == "" { + cf.HomePath = t.TempDir() + } + + clientStore, err := initClientStore(cf, cf.Proxy) + require.NoError(t, err) + err = clientStore.SaveProfile(&profile.Profile{ + SSHProxyAddr: cf.Proxy, + WebProxyAddr: cf.Proxy, + }, true) + require.NoError(t, err) +} diff --git a/tool/tsh/common/help.go b/tool/tsh/common/help.go index 427a3be93702d..71ac21769b1db 100644 --- a/tool/tsh/common/help.go +++ b/tool/tsh/common/help.go @@ -62,4 +62,18 @@ Examples: Get database names using "jq": $ tsh db ls --format json | jq -r '.[].metadata.name'` + + dbExecHelp = ` +Examples: + Search databases with labels: + $ tsh db exec "source my_script.sql" --db-user mysql --labels key1=value1,key2=value2 + + Search databases with keywords: + $ tsh db exec "select 1" --db-user mysql --db-name mysql --search foo,bar + + Execute a command on specified target databases without confirmation: + $ tsh db exec "select @@hostname" --db-user mysql --dbs mydb1,mydb2,mydb3 --no-confirm + + Run commands in parallel, and save outputs to files: + $ tsh db exec "select 1" --db-user mysql --labels env=dev --parallel=5 --output-dir=exec-outputs` ) diff --git a/tool/tsh/common/proxy.go b/tool/tsh/common/proxy.go index 45a112a254ad3..f1faf9bc57fae 100644 --- a/tool/tsh/common/proxy.go +++ b/tool/tsh/common/proxy.go @@ -216,21 +216,15 @@ func onProxyCommandDB(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - opts := []dbcmd.ConnectCommandFunc{ + opts, err := makeDatabaseCommandOptions(cf.Context, tc, dbInfo, dbcmd.WithLocalProxy("localhost", addr.Port(0), ""), dbcmd.WithNoTLS(), - dbcmd.WithLogger(logger), dbcmd.WithPrintFormat(), dbcmd.WithTolerateMissingCLIClient(), - dbcmd.WithGetDatabaseFunc(dbInfo.getDatabaseForDBCmd), - } - if opts, err = maybeAddDBUserPassword(cf, tc, dbInfo, opts); err != nil { - return trace.Wrap(err) - } - if opts, err = maybeAddGCPMetadata(cf.Context, tc, dbInfo, opts); err != nil { + ) + if err != nil { return trace.Wrap(err) } - opts = maybeAddOracleOptions(cf.Context, tc, dbInfo, opts) commands, err := dbcmd.NewCmdBuilder(tc, profile, dbInfo.RouteToDatabase, rootCluster, opts..., @@ -278,9 +272,9 @@ func onProxyCommandDB(cf *CLIConf) error { return nil } -func maybeAddDBUserPassword(cf *CLIConf, tc *libclient.TeleportClient, dbInfo *databaseInfo, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) { +func maybeAddDBUserPassword(ctx context.Context, tc *libclient.TeleportClient, dbInfo *databaseInfo, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) { if dbInfo.Protocol == defaults.ProtocolCassandra { - db, err := dbInfo.GetDatabase(cf.Context, tc) + db, err := dbInfo.GetDatabase(ctx, tc) if err != nil { return nil, trace.Wrap(err) } @@ -302,6 +296,22 @@ func requiresGCPMetadata(protocol string) bool { return protocol == defaults.ProtocolSpanner } +func makeDatabaseCommandOptions(ctx context.Context, tc *libclient.TeleportClient, dbInfo *databaseInfo, extraOpts ...dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) { + var err error + opts := append([]dbcmd.ConnectCommandFunc{ + dbcmd.WithLogger(logger), + dbcmd.WithGetDatabaseFunc(dbInfo.getDatabaseForDBCmd), + }, extraOpts...) + + if opts, err = maybeAddDBUserPassword(ctx, tc, dbInfo, opts); err != nil { + return nil, trace.Wrap(err) + } + if opts, err = maybeAddGCPMetadata(ctx, tc, dbInfo, opts); err != nil { + return nil, trace.Wrap(err) + } + return maybeAddOracleOptions(ctx, tc, dbInfo, opts), nil +} + func maybeAddGCPMetadata(ctx context.Context, tc *libclient.TeleportClient, dbInfo *databaseInfo, opts []dbcmd.ConnectCommandFunc) ([]dbcmd.ConnectCommandFunc, error) { if !requiresGCPMetadata(dbInfo.Protocol) { return opts, nil diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 5dd1a298259e2..1107984f777f0 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -67,6 +67,7 @@ import ( "github.com/gravitational/teleport/api/types/accesslist" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/wrappers" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keys/hardwarekey" "github.com/gravitational/teleport/api/utils/keys/piv" "github.com/gravitational/teleport/api/utils/prompt" @@ -241,12 +242,17 @@ type CLIConf struct { // DatabaseService specifies the database proxy server to log into. DatabaseService string + // DatabaseServices specifies a list of database services. + DatabaseServices string // DatabaseUser specifies database user to embed in the certificate. DatabaseUser string // DatabaseName specifies database name to embed in the certificate. DatabaseName string // DatabaseRoles specifies database roles to embed in the certificate. DatabaseRoles string + // DatabaseCommand specifies the command to execute. + DatabaseCommand string + // AppName specifies proxied application name. AppName string // Interactive sessions will allocate a PTY and create interactive "shell" @@ -586,6 +592,13 @@ type CLIConf struct { // to the hardware key agent. Some commands, like login, are better with the // direct PIV service so that prompts are not split between processes. disableHardwareKeyAgentClient bool + + // ParallelJobs specifies the number of parallel jobs allowed. + ParallelJobs int + // OutputDir specifies the directory for storing command outputs. + OutputDir string + // Confirm determines whether to provide a y/N confirmation prompt. + Confirm bool } // Stdout returns the stdout writer. @@ -633,6 +646,23 @@ func (c *CLIConf) LookPath(file string) (string, error) { return exec.LookPath(file) } +// PromptConfirmation prompts the user for a yes/no confirmation for question. +// The prompt is skipped unless cf.Confirm is set. +func (c *CLIConf) PromptConfirmation(question string) error { + if !c.Confirm { + fmt.Fprintf(c.Stdout(), "Skipping confirmation for %q due to the --no-confirm flag.\n", question) + return nil + } + + ok, err := prompt.Confirmation(c.Context, c.Stdout(), prompt.Stdin(), question) + if err != nil { + return trace.Wrap(err) + } else if !ok { + return trace.Errorf("Operation canceled by user request.") + } + return nil +} + func Main() { cmdLineOrig := os.Args[1:] var cmdLine []string @@ -1030,6 +1060,18 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { dbConnect.Flag("request-reason", "Reason for requesting access").StringVar(&cf.RequestReason) dbConnect.Flag("disable-access-request", "Disable automatic resource access requests").BoolVar(&cf.disableAccessRequest) dbConnect.Flag("tunnel", "Open authenticated tunnel using database's client certificate so clients don't need to authenticate").Hidden().BoolVar(&cf.LocalProxyTunnel) + dbExec := db.Command("exec", "Execute database commands on target database services.") + dbExec.Flag("db-user", "Database user to log in as.").Short('u').StringVar(&cf.DatabaseUser) + dbExec.Flag("db-name", "Database name to log in to.").Short('n').StringVar(&cf.DatabaseName) + dbExec.Flag("db-roles", "List of comma separate database roles to use for auto-provisioned user.").Short('r').StringVar(&cf.DatabaseRoles) + dbExec.Flag("search", searchHelp).StringVar(&cf.SearchKeywords) + dbExec.Flag("labels", labelHelp).StringVar(&cf.Labels) + dbExec.Flag("parallel", "Run commands on target databases in parallel. Defaults to 1, and maximum allowed is 10.").Default("1").IntVar(&cf.ParallelJobs) + dbExec.Flag("output-dir", "Directory to store command output per target database service. A summary is saved as \"summary.json\".").StringVar(&cf.OutputDir) + dbExec.Flag("dbs", "List of comma separated target database services. Mutually exclusive with --search or --labels.").StringVar(&cf.DatabaseServices) + dbExec.Flag("confirm", "Confirm selected database services before executing command.").Default("true").BoolVar(&cf.Confirm) + dbExec.Arg("command", "Execute this command on target database services.").Required().StringVar(&cf.DatabaseCommand) + dbExec.Alias(dbExecHelp) // join join := app.Command("join", "Join the active SSH or Kubernetes session.") @@ -1592,6 +1634,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { err = onDatabaseConfig(&cf) case dbConnect.FullCommand(): err = onDatabaseConnect(&cf) + case dbExec.FullCommand(): + err = onDatabaseExec(&cf) case environment.FullCommand(): err = onEnvironment(&cf) case mfa.ls.FullCommand(): @@ -5867,3 +5911,14 @@ func tryLockMemory(cf *CLIConf) error { return trace.BadParameter("unexpected value for --mlock, expected one of (%v)", strings.Join(mlockModes, ", ")) } } + +// stringFlagToStrings parses a comma-separated string from a CLIConf flag into +// a slice of strings. It trims whitespace from each value and removes +// duplicates. +func stringFlagToStrings(value string) []string { + values := strings.Split(value, ",") + for i := range values { + values[i] = strings.TrimSpace(values[i]) + } + return apiutils.Deduplicate(values) +}