Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a39bdbc
Implement basic "tsh db exec"
greedy52 Mar 12, 2025
c7cbb50
adding ut
greedy52 Mar 21, 2025
f4e7900
minor refactor, fix race, rename iter func
greedy52 Mar 24, 2025
23f9c19
add help
greedy52 Mar 24, 2025
62a980d
always use service name
greedy52 Mar 27, 2025
7894671
overwrite max connections with env var
greedy52 Mar 27, 2025
ff0c3e8
single get databases call
greedy52 Mar 27, 2025
5f5bddc
remove prefix output
greedy52 Mar 27, 2025
b28400c
fix some flags
greedy52 Mar 27, 2025
5f2acfb
iterutils
greedy52 Mar 27, 2025
687f418
ensure each database
greedy52 Mar 28, 2025
cf5b23a
add summery
greedy52 Mar 28, 2025
0508880
refactor, tests
greedy52 Mar 28, 2025
2489a98
revert auto rename change by editor
greedy52 Mar 31, 2025
598d6cb
revert migrate
greedy52 Mar 31, 2025
2d12569
Merge branch 'master' of github.com:gravitational/teleport into STeve…
greedy52 Mar 31, 2025
3ff0cb7
remove unused var
greedy52 Mar 31, 2025
b93d422
review comments
greedy52 Apr 10, 2025
678ed7d
renaming --max-connections to --parallel
greedy52 Apr 10, 2025
a0c22f2
make exec return result instead of error
greedy52 Apr 10, 2025
e79882c
hint TELEPORT_PARALLEL_JOBS
greedy52 Apr 17, 2025
73cdd48
Merge branch 'master' of github.com:gravitational/teleport into STeve…
greedy52 Apr 17, 2025
64e6c76
fix golint
greedy52 Apr 17, 2025
1300f3d
address PR comments
greedy52 Apr 24, 2025
4c8dd41
Merge branch 'master' of github.com:gravitational/teleport into STeve…
greedy52 Apr 24, 2025
b213de1
Merge branch 'master' into STeve/51679_base_tsh_db_exec
greedy52 Apr 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions api/types/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package types

import (
"iter"
"regexp"
"slices"
"sort"
Expand All @@ -30,6 +31,7 @@ import (
"github.com/gravitational/teleport/api/types/common"
"github.com/gravitational/teleport/api/types/compare"
"github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/iterutils"
)

var (
Expand Down Expand Up @@ -84,6 +86,12 @@ func GetName[R Resource](r R) string {
return r.GetName()
}

// ResourceNames creates an iterator that loops through the provided slice of
// resources and return their names.
func ResourceNames[R Resource, S ~[]R](s S) iter.Seq[string] {
return iterutils.Map(GetName, slices.Values(s))
}

// ResourceDetails includes details about the resource
type ResourceDetails struct {
Hostname string
Expand Down
19 changes: 19 additions & 0 deletions api/types/resource_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package types

import (
"fmt"
"slices"
"testing"
"time"

Expand Down Expand Up @@ -820,3 +822,20 @@ func TestResourceHeaderIsEqual(t *testing.T) {
})
}
}

func TestResourceNames(t *testing.T) {
var apps Apps
var expectedNames []string
for i := 0; i < 10; i++ {
app, err := NewAppV3(Metadata{
Name: fmt.Sprintf("app-%d", i),
}, AppSpecV3{
URI: "tcp://localhost:1111",
})
require.NoError(t, err)
apps = append(apps, app)
expectedNames = append(expectedNames, app.GetName())
}

require.Equal(t, expectedNames, slices.Collect(ResourceNames(apps)))
}
37 changes: 37 additions & 0 deletions api/utils/iterutils/iter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* Teleport
* Copyright (C) 2025 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package iterutils

import (
"iter"
)

// Map returns an iterator over f applied to seq.
//
// Copied from https://github.com/golang/go/issues/61898. We should switch to an
// official package once it is available.
func Map[In, Out any](f func(In) Out, seq iter.Seq[In]) iter.Seq[Out] {
return func(yield func(Out) bool) {
for in := range seq {
if !yield(f(in)) {
return
}
}
}
}
39 changes: 39 additions & 0 deletions api/utils/iterutils/iter_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Teleport
* Copyright (C) 2025 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package iterutils

import (
"fmt"
"slices"
"strings"
)

func ExampleMap() {
inputs := []string{
"hello world",
"foo",
}

for mapped := range Map(strings.ToUpper, slices.Values(inputs)) {
fmt.Println(mapped)
}
// Output:
// HELLO WORLD
// FOO
}
60 changes: 60 additions & 0 deletions lib/client/db/dbcmd/exec.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Teleport
* Copyright (C) 2025 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package dbcmd

import (
"context"
"os/exec"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/lib/defaults"
)

// GetExecCommand returns a command that executes the provided query on the
// target database using an appropriate CLI database client.
func (c *CLICommandBuilder) GetExecCommand(_ context.Context, query string) (*exec.Cmd, error) {
if !c.options.noTLS || c.options.localProxyHost == "" {
return nil, trace.BadParameter("query execution is only supported when using an authenticated local proxy")
}

switch c.db.Protocol {
case defaults.ProtocolPostgres:
return c.getPostgresExecCommand(query)
case defaults.ProtocolMySQL:
return c.getMySQLExecCommand(query)
default:
return nil, trace.BadParameter("%s databases not supported for exec command", c.db.Protocol)
}
}

func (c *CLICommandBuilder) getPostgresExecCommand(query string) (*exec.Cmd, error) {
cmd := c.getPostgresCommand()
cmd.Args = append(cmd.Args, "-c", query)
return cmd, nil
}

func (c *CLICommandBuilder) getMySQLExecCommand(query string) (*exec.Cmd, error) {
cmd, err := c.getMySQLCommand()
if err != nil {
return nil, trace.Wrap(err)
}
cmd.Args = append(cmd.Args, "-e", query)
return cmd, nil
}
120 changes: 120 additions & 0 deletions lib/client/db/dbcmd/exec_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Teleport
* Copyright (C) 2025 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package dbcmd

import (
"context"
"testing"

"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/observability/tracing"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
)

func TestCLICommandBuilderGetExecCommand(t *testing.T) {
fakeExec := &fakeExec{
execOutput: map[string][]byte{
"psql": []byte(""),
"mysql": []byte(""),
},
}

conf := &client.Config{
HomePath: t.TempDir(),
Host: "localhost",
WebProxyAddr: "proxy.example.com",
SiteName: "db.example.com",
Tracer: tracing.NoopProvider().Tracer("test"),
}

tc, err := client.NewClient(conf)
require.NoError(t, err)

profile := &client.ProfileStatus{
Name: "example.com",
Username: "bob",
Dir: "/tmp",
Cluster: "example.com",
}

tests := []struct {
name string
opts []ConnectCommandFunc
protocol string
cmd []string
wantErr bool
}{
{
name: "not authenticated tunnel",
protocol: defaults.ProtocolPostgres,
wantErr: true,
},
{
name: "unsupported protocol",
protocol: defaults.ProtocolDynamoDB,
opts: []ConnectCommandFunc{WithNoTLS()},
wantErr: true,
},
{
name: "postgres",
protocol: defaults.ProtocolPostgres,
opts: []ConnectCommandFunc{WithNoTLS()},
cmd: []string{"psql", "postgres://db-user@localhost:12345/db-name", "-c", "select 1"},
wantErr: false,
},
{
name: "mysql",
protocol: defaults.ProtocolMySQL,
opts: []ConnectCommandFunc{WithNoTLS()},
cmd: []string{"mysql", "--user", "db-user", "--database", "db-name", "--port", "12345", "--host", "localhost", "--protocol", "TCP", "-e", "select 1"},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
database := tlsca.RouteToDatabase{
Protocol: tt.protocol,
Database: "db-name",
Username: "db-user",
ServiceName: "db-service",
}

opts := append([]ConnectCommandFunc{
WithLocalProxy("localhost", 12345, ""),
WithExecer(fakeExec),
}, tt.opts...)

c := NewCmdBuilder(tc, profile, database, "root", opts...)
c.uid = utils.NewFakeUID()
got, err := c.GetExecCommand(context.Background(), "select 1")
if tt.wantErr {
require.Error(t, err)
return
}

require.NoError(t, err)
require.Equal(t, tt.cmd, got.Args)
})
}
}
10 changes: 10 additions & 0 deletions lib/utils/fs.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,3 +499,13 @@ func RecursiveCopy(src, dest string, skip func(src, dest string) (bool, error))
return nil
}))
}

// CreateExclusiveFile creates a file only if it does not exist to prevent overwriting
// existing files.
func CreateExclusiveFile(path string, mode os.FileMode) (*os.File, error) {
out, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, mode)
if err != nil {
return nil, trace.ConvertSystemError(err)
}
return out, nil
}
18 changes: 18 additions & 0 deletions lib/utils/log/slog.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ package log
import (
"context"
"fmt"
"iter"
"log/slog"
"reflect"
"slices"
"strings"
"unicode"

Expand Down Expand Up @@ -201,3 +203,19 @@ func (a typeAttr) LogValue() slog.Value {
}
return slog.StringValue("nil")
}

type iterAttr[V any] struct {
iter iter.Seq[V]
}

// IterAttr creates a [slog.LogValuer] that will defer the collection of an
// iter.Seq.
func IterAttr[V any](iter iter.Seq[V]) slog.LogValuer {
return iterAttr[V]{
iter: iter,
}
}

func (a iterAttr[V]) LogValue() slog.Value {
return slog.AnyValue(slices.Collect[V](a.iter))
}
2 changes: 1 addition & 1 deletion lib/utils/unpack.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ func writeFile(path string, r io.Reader, mode, dirMode os.FileMode) error {
err := withDir(path, dirMode, func() error {
// Create file only if it does not exist to prevent overwriting existing
// files (like session recordings).
out, err := os.OpenFile(path, os.O_CREATE|os.O_EXCL|os.O_WRONLY, mode)
out, err := CreateExclusiveFile(path, mode)
if err != nil {
return trace.ConvertSystemError(err)
}
Expand Down
Loading
Loading