Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 64 additions & 34 deletions tool/tsh/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package main

import (
"bytes"
"context"
"encoding/base64"
"fmt"
Expand Down Expand Up @@ -402,17 +403,11 @@ func onDatabaseEnv(cf *CLIConf) error {
}

if !dbprofile.IsSupported(*database) {
return trace.BadParameter(dbCmdUnsupportedDBProtocol,
cf.CommandWithBinary(),
defaults.ReadableDatabaseProtocol(database.Protocol),
)
return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, database))
}
// MySQL requires ALPN local proxy in signle port mode.
// MySQL requires ALPN local proxy in single port mode.
if tc.TLSRoutingEnabled && database.Protocol == defaults.ProtocolMySQL {
return trace.BadParameter(dbCmdUnsupportedTLSRouting,
cf.CommandWithBinary(),
defaults.ReadableDatabaseProtocol(database.Protocol),
)
return trace.BadParameter(formatDbCmdUnsupportedTLSRouting(cf, database))
}

env, err := dbprofile.Env(tc, *database)
Expand Down Expand Up @@ -469,17 +464,11 @@ func onDatabaseConfig(cf *CLIConf) error {
// the remote proxy directly. Return errors here when direct connection
// does NOT work (e.g. when ALPN local proxy is required).
if isLocalProxyAlwaysRequired(database.Protocol) {
return trace.BadParameter(dbCmdUnsupportedDBProtocol,
cf.CommandWithBinary(),
defaults.ReadableDatabaseProtocol(database.Protocol),
)
return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, database))
}
// MySQL requires ALPN local proxy in signle port mode.
// MySQL requires ALPN local proxy in single port mode.
if tc.TLSRoutingEnabled && database.Protocol == defaults.ProtocolMySQL {
return trace.BadParameter(dbCmdUnsupportedTLSRouting,
cf.CommandWithBinary(),
defaults.ReadableDatabaseProtocol(database.Protocol),
)
return trace.BadParameter(formatDbCmdUnsupportedTLSRouting(cf, database))
}

host, port := tc.DatabaseProxyHostPort(*database)
Expand Down Expand Up @@ -560,9 +549,10 @@ func maybeStartLocalProxy(ctx context.Context, cf *CLIConf, tc *client.TeleportC
return []dbcmd.ConnectCommandFunc{}, nil
}

// Some protocols (Snowflake, Elasticsearch) only works in the local tunnel mode.
// Some protocols (Snowflake) only work in the local tunnel mode.
// ElasticSearch can work without the --tunnel flag, but not via `tsh db connect`.
localProxyTunnel := cf.LocalProxyTunnel
if db.Protocol == defaults.ProtocolSnowflake || db.Protocol == defaults.ProtocolElasticsearch {
if requiresLocalProxyTunnel(db.Protocol) || db.Protocol == defaults.ProtocolElasticsearch {
localProxyTunnel = true
}

Expand Down Expand Up @@ -1044,6 +1034,42 @@ func isLocalProxyAlwaysRequired(protocol string) bool {
}
}

// formatDbCmdUnsupportedWithCondition is a helper func that formats a generic unsupported DB error message.
// The condition argument is optional and can be "", but otherwise it should be a specific condition for which this DB subcommand
// is not supported, e.g. "when TLS routing is enabled" or "without using the --tunnel flag".
func formatDbCmdUnsupportedWithCondition(cf *CLIConf, database *tlsca.RouteToDatabase, condition string) string {
templateData := map[string]any{
"command": cf.CommandWithBinary(),
"protocol": defaults.ReadableDatabaseProtocol(database.Protocol),
"alternatives": getDbCmdAlternatives(cf.SiteName, database),
"condition": condition,
}

buf := bytes.NewBuffer(nil)
_ = dbCmdUnsupportedTemplate.Execute(buf, templateData)
return buf.String()
}

// formatDbCmdUnsupportedDBProtocol is a helper func that formats the unsupported DB protocol error message unconditionally.
func formatDbCmdUnsupportedDBProtocol(cf *CLIConf, database *tlsca.RouteToDatabase) string {
return formatDbCmdUnsupportedWithCondition(cf, database, "")
}

// formatDbCmdUnsupportedTLSRouting is a helper func that formats an unsupported DB Protocol error with a TLS routing condition.
func formatDbCmdUnsupportedTLSRouting(cf *CLIConf, database *tlsca.RouteToDatabase) string {
return formatDbCmdUnsupportedWithCondition(cf, database, "when TLS routing is enabled on the Teleport Proxy Service")
}

// getDbCmdAlternatives is a helper func that returns alternative tsh commands for connecting to a database.
func getDbCmdAlternatives(clusterFlag string, database *tlsca.RouteToDatabase) []string {
var alts []string
// prefer displaying the connect command as the first suggested command alternative.
alts = append(alts, formatDatabaseConnectCommand(clusterFlag, *database))
// all db protocols support this command.
alts = append(alts, formatDatabaseProxyCommand(clusterFlag, *database))
return alts
}

const (
// dbFormatText prints database configuration in text format.
dbFormatText = "text"
Expand All @@ -1055,36 +1081,40 @@ const (
dbFormatYAML = "yaml"
)

const (
// dbCmdUnsupportedTLSRouting is the error message printed when some
// database subcommands are not supported because TLS routing is enabled.
dbCmdUnsupportedTLSRouting = `"%v" is not supported for %v databases when TLS routing is enabled on the Teleport Proxy Service.

Please use "tsh db connect" or "tsh proxy db" to connect to the database.`

// dbCmdUnsupportedDBProtocol is the error message printed when some
// database subcommands are run against unsupported database protocols.
dbCmdUnsupportedDBProtocol = `"%v" is not supported for %v databases.

Please use "tsh db connect" or "tsh proxy db" to connect to the database.`
var (
// dbCmdUnsupportedTemplate is the error message printed when some
// database subcommands are not supported.
dbCmdUnsupportedTemplate = template.Must(template.New("").Parse(`"{{.command}}" is not supported for {{.protocol}} databases{{if .condition}} {{.condition}}{{end}}.
{{if eq (len .alternatives) 1}}
Please use the following command to connect to the database:
{{index .alternatives 0 -}}{{else}}
Please use one of the following commands to connect to the database:
{{- range .alternatives}}
{{.}}{{end -}}
{{- end}}`))
)

var (
// dbConnectTemplate is the message printed after a successful "tsh db login" on how to connect.
dbConnectTemplate = template.Must(template.New("").Parse(`Connection information for database "{{ .name }}" has been saved.

{{if .connectCommand -}}

You can now connect to it using the following command:

{{.connectCommand}}

{{end -}}
{{if .configCommand -}}
Or view the connect command for the native database CLI client:

You can view the connect command for the native database CLI client:

{{ .configCommand }}

{{end -}}
{{if .proxyCommand -}}
Or start a local proxy for database GUI clients:

You can start a local proxy for database GUI clients:

{{ .proxyCommand }}

Expand Down
4 changes: 2 additions & 2 deletions tool/tsh/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ func TestDatabaseLogin(t *testing.T) {
// Verify certificates.
certs, keys, err := decodePEM(profile.DatabaseCertPathForCluster("", test.databaseName))
require.NoError(t, err)
require.Len(t, certs, test.expectCertsLen)
require.Len(t, keys, test.expectKeysLen)
require.Equal(t, test.expectCertsLen, len(certs)) // don't use require.Len, because it spams PEM bytes on fail.
require.Equal(t, test.expectKeysLen, len(keys)) // don't use require.Len, because it spams PEM bytes on fail.
})
}

Expand Down
20 changes: 16 additions & 4 deletions tool/tsh/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,12 +379,14 @@ func onProxyCommandDB(cf *CLIConf) error {
if err != nil {
return trace.Wrap(err)
}
if err := maybeDatabaseLogin(cf, client, profile, routeToDatabase); err != nil {
return trace.Wrap(err)

// Some protocols require the --tunnel flag, e.g. Snowflake.
if !cf.LocalProxyTunnel && requiresLocalProxyTunnel(routeToDatabase.Protocol) {
return trace.BadParameter(formatDbCmdUnsupportedWithCondition(cf, routeToDatabase, "without the --tunnel flag"))
}

if routeToDatabase.Protocol == defaults.ProtocolSnowflake && !cf.LocalProxyTunnel {
return trace.BadParameter("Snowflake proxy works only in the tunnel mode. Please add --tunnel flag to enable it")
if err := maybeDatabaseLogin(cf, client, profile, routeToDatabase); err != nil {
return trace.Wrap(err)
}

rootCluster, err := client.RootClusterName(cf.Context)
Expand Down Expand Up @@ -813,6 +815,16 @@ func envVarCommand(format, key, value string) (string, error) {
}
}

// requiresLocalProxyTunnel returns whether the given protocol requires a local proxy with the --tunnel flag.
func requiresLocalProxyTunnel(protocol string) bool {
switch protocol {
case defaults.ProtocolSnowflake:
return true
default:
return false
}
}

var awsTemplateFuncs = template.FuncMap{
"envVarCommand": envVarCommand,
}
Expand Down