Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 47 additions & 49 deletions lib/client/db/dbcmd/dbcmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,17 +120,23 @@ type CLICommandBuilder struct {
// --cluster flag. Therefore profile.Cluster is not suitable for
// determining the target cluster or the root cluster. Use tc.SiteName for
// the target cluster and rootCluster for root cluster.
profile *client.ProfileStatus
db *tlsca.RouteToDatabase
host string
port int
options connectionCommandOpts
uid utils.UID
profile *client.ProfileStatus
db *tlsca.RouteToDatabase
host string
port int
options connectionCommandOpts
uid utils.UID
getDatabaseFunc GetDatabaseFunc
}

func NewCmdBuilder(tc *client.TeleportClient, profile *client.ProfileStatus,
db tlsca.RouteToDatabase, rootClusterName string, opts ...ConnectCommandFunc,
) *CLICommandBuilder {
func NewCmdBuilder(
tc *client.TeleportClient,
profile *client.ProfileStatus,
db tlsca.RouteToDatabase,
rootClusterName string,
getDatabaseFunc GetDatabaseFunc,
opts ...ConnectCommandFunc,
) (*CLICommandBuilder, error) {
var options connectionCommandOpts
for _, opt := range opts {
opt(&options)
Expand All @@ -151,16 +157,21 @@ func NewCmdBuilder(tc *client.TeleportClient, profile *client.ProfileStatus,
options.exe = &SystemExecer{}
}

return &CLICommandBuilder{
tc: tc,
profile: profile,
db: &db,
host: host,
port: port,
options: options,
rootCluster: rootClusterName,
uid: utils.NewRealUID(),
if getDatabaseFunc == nil {
return nil, trace.BadParameter("GetDatabaseFunc is required and cannot be nil")
}

return &CLICommandBuilder{
tc: tc,
profile: profile,
db: &db,
host: host,
port: port,
options: options,
rootCluster: rootClusterName,
getDatabaseFunc: getDatabaseFunc,
uid: utils.NewRealUID(),
}, nil
}

// GetConnectCommand returns a command that can connect the user directly to the given database
Expand Down Expand Up @@ -219,7 +230,7 @@ func (c *CLICommandBuilder) GetConnectCommand(ctx context.Context) (*exec.Cmd, e
return c.getClickhouseNativeCommand()

case defaults.ProtocolSpanner:
return c.getSpannerCommand()
return c.getSpannerCommand(ctx)
}

return nil, trace.BadParameter("unsupported database protocol: %v", c.db)
Expand Down Expand Up @@ -546,12 +557,7 @@ func (c *CLICommandBuilder) getMongoCommand(ctx context.Context) (*exec.Cmd, err
}

func (c *CLICommandBuilder) getDatabase(ctx context.Context) (types.Database, error) {
// Technically, we can just use tc to get the database. But caller may have
// extra logic so rely on the callback for now.
if c.options.getDatabase == nil {
return nil, trace.NotFound("missing GetDatabaseFunc")
}
db, err := c.options.getDatabase(ctx, c.tc, c.db.ServiceName)
db, err := c.getDatabaseFunc(ctx, c.tc, c.db.ServiceName)
return db, trace.Wrap(err)
}

Expand Down Expand Up @@ -749,25 +755,34 @@ func (c *CLICommandBuilder) getDynamoDBCommand() (*exec.Cmd, error) {
return exec.Command(awsBin, args...), nil
}

func (c *CLICommandBuilder) getSpannerCommand() (*exec.Cmd, error) {
func (c *CLICommandBuilder) getSpannerCommand(ctx context.Context) (*exec.Cmd, error) {
if err := c.checkLocalProxyTunnelOnly(false); err != nil {
return nil, trace.Wrap(err)
}

var (
gcp types.GCPCloudSQL
project,
instance,
database string
)
if c.options.printFormat {
// default placeholders for a print command if not all info is available

db, err := c.getDatabase(ctx)
switch {
case err != nil && c.options.printFormat:
// OK to continue in this case, we'll print the placeholders instead.
project, instance, database = "<project>", "<instance>", "<database>"
case err != nil:
return nil, trace.Wrap(err)
default:
gcp = db.GetGCP()
}

if c.options.gcp.ProjectID != "" {
project = c.options.gcp.ProjectID
if gcp.ProjectID != "" {
project = gcp.ProjectID
}
if c.options.gcp.InstanceID != "" {
instance = c.options.gcp.InstanceID
if gcp.InstanceID != "" {
instance = gcp.InstanceID
}
if c.db.Database != "" {
database = c.db.Database
Expand Down Expand Up @@ -947,9 +962,7 @@ type connectionCommandOpts struct {
log *logrus.Entry
exe Execer
password string
gcp types.GCPCloudSQL
oracle oracleOpts
getDatabase GetDatabaseFunc
}

// ConnectCommandFunc is a type for functions returned by the "With*" functions in this package.
Expand Down Expand Up @@ -1052,24 +1065,9 @@ func WithOracleOpts(canUseTCP bool, hasTCPServers bool) ConnectCommandFunc {
}
}

// WithGCP adds GCP metadata for the database command to access.
// TODO(greedy52) use GetDatabaseFunc instead.
func WithGCP(gcp types.GCPCloudSQL) ConnectCommandFunc {
return func(opts *connectionCommandOpts) {
opts.gcp = gcp
}
}

// GetDatabaseFunc is a callback to retrieve types.Database.
type GetDatabaseFunc func(context.Context, *client.TeleportClient, string) (types.Database, error)

// WithGetDatabaseFunc provides a callback to retrieve types.Database.
func WithGetDatabaseFunc(f GetDatabaseFunc) ConnectCommandFunc {
return func(opts *connectionCommandOpts) {
opts.getDatabase = f
}
}

const (
// envVarMongoServerSelectionTimeoutMS is the environment variable that
// controls the server selection timeout used for MongoDB clients.
Expand Down
Loading
Loading