diff --git a/tool/tsh/db.go b/tool/tsh/db.go index a9592af9812c0..b52f1fd05f18c 100644 --- a/tool/tsh/db.go +++ b/tool/tsh/db.go @@ -17,6 +17,7 @@ limitations under the License. package main import ( + "bytes" "context" "encoding/base64" "fmt" @@ -402,17 +403,11 @@ func onDatabaseEnv(cf *CLIConf) error { } if !dbprofile.IsSupported(*database) { - return trace.BadParameter(dbCmdUnsupportedDBProtocol, - cf.CommandWithBinary(), - defaults.ReadableDatabaseProtocol(database.Protocol), - ) + return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, database)) } - // MySQL requires ALPN local proxy in signle port mode. + // MySQL requires ALPN local proxy in single port mode. if tc.TLSRoutingEnabled && database.Protocol == defaults.ProtocolMySQL { - return trace.BadParameter(dbCmdUnsupportedTLSRouting, - cf.CommandWithBinary(), - defaults.ReadableDatabaseProtocol(database.Protocol), - ) + return trace.BadParameter(formatDbCmdUnsupportedTLSRouting(cf, database)) } env, err := dbprofile.Env(tc, *database) @@ -469,17 +464,11 @@ func onDatabaseConfig(cf *CLIConf) error { // the remote proxy directly. Return errors here when direct connection // does NOT work (e.g. when ALPN local proxy is required). if isLocalProxyAlwaysRequired(database.Protocol) { - return trace.BadParameter(dbCmdUnsupportedDBProtocol, - cf.CommandWithBinary(), - defaults.ReadableDatabaseProtocol(database.Protocol), - ) + return trace.BadParameter(formatDbCmdUnsupportedDBProtocol(cf, database)) } - // MySQL requires ALPN local proxy in signle port mode. + // MySQL requires ALPN local proxy in single port mode. if tc.TLSRoutingEnabled && database.Protocol == defaults.ProtocolMySQL { - return trace.BadParameter(dbCmdUnsupportedTLSRouting, - cf.CommandWithBinary(), - defaults.ReadableDatabaseProtocol(database.Protocol), - ) + return trace.BadParameter(formatDbCmdUnsupportedTLSRouting(cf, database)) } host, port := tc.DatabaseProxyHostPort(*database) @@ -560,9 +549,10 @@ func maybeStartLocalProxy(ctx context.Context, cf *CLIConf, tc *client.TeleportC return []dbcmd.ConnectCommandFunc{}, nil } - // Some protocols (Snowflake, Elasticsearch) only works in the local tunnel mode. + // Some protocols (Snowflake) only work in the local tunnel mode. + // ElasticSearch can work without the --tunnel flag, but not via `tsh db connect`. localProxyTunnel := cf.LocalProxyTunnel - if db.Protocol == defaults.ProtocolSnowflake || db.Protocol == defaults.ProtocolElasticsearch { + if requiresLocalProxyTunnel(db.Protocol) || db.Protocol == defaults.ProtocolElasticsearch { localProxyTunnel = true } @@ -1044,6 +1034,42 @@ func isLocalProxyAlwaysRequired(protocol string) bool { } } +// formatDbCmdUnsupportedWithCondition is a helper func that formats a generic unsupported DB error message. +// The condition argument is optional and can be "", but otherwise it should be a specific condition for which this DB subcommand +// is not supported, e.g. "when TLS routing is enabled" or "without using the --tunnel flag". +func formatDbCmdUnsupportedWithCondition(cf *CLIConf, database *tlsca.RouteToDatabase, condition string) string { + templateData := map[string]any{ + "command": cf.CommandWithBinary(), + "protocol": defaults.ReadableDatabaseProtocol(database.Protocol), + "alternatives": getDbCmdAlternatives(cf.SiteName, database), + "condition": condition, + } + + buf := bytes.NewBuffer(nil) + _ = dbCmdUnsupportedTemplate.Execute(buf, templateData) + return buf.String() +} + +// formatDbCmdUnsupportedDBProtocol is a helper func that formats the unsupported DB protocol error message unconditionally. +func formatDbCmdUnsupportedDBProtocol(cf *CLIConf, database *tlsca.RouteToDatabase) string { + return formatDbCmdUnsupportedWithCondition(cf, database, "") +} + +// formatDbCmdUnsupportedTLSRouting is a helper func that formats an unsupported DB Protocol error with a TLS routing condition. +func formatDbCmdUnsupportedTLSRouting(cf *CLIConf, database *tlsca.RouteToDatabase) string { + return formatDbCmdUnsupportedWithCondition(cf, database, "when TLS routing is enabled on the Teleport Proxy Service") +} + +// getDbCmdAlternatives is a helper func that returns alternative tsh commands for connecting to a database. +func getDbCmdAlternatives(clusterFlag string, database *tlsca.RouteToDatabase) []string { + var alts []string + // prefer displaying the connect command as the first suggested command alternative. + alts = append(alts, formatDatabaseConnectCommand(clusterFlag, *database)) + // all db protocols support this command. + alts = append(alts, formatDatabaseProxyCommand(clusterFlag, *database)) + return alts +} + const ( // dbFormatText prints database configuration in text format. dbFormatText = "text" @@ -1055,36 +1081,40 @@ const ( dbFormatYAML = "yaml" ) -const ( - // dbCmdUnsupportedTLSRouting is the error message printed when some - // database subcommands are not supported because TLS routing is enabled. - dbCmdUnsupportedTLSRouting = `"%v" is not supported for %v databases when TLS routing is enabled on the Teleport Proxy Service. - -Please use "tsh db connect" or "tsh proxy db" to connect to the database.` - - // dbCmdUnsupportedDBProtocol is the error message printed when some - // database subcommands are run against unsupported database protocols. - dbCmdUnsupportedDBProtocol = `"%v" is not supported for %v databases. - -Please use "tsh db connect" or "tsh proxy db" to connect to the database.` +var ( + // dbCmdUnsupportedTemplate is the error message printed when some + // database subcommands are not supported. + dbCmdUnsupportedTemplate = template.Must(template.New("").Parse(`"{{.command}}" is not supported for {{.protocol}} databases{{if .condition}} {{.condition}}{{end}}. +{{if eq (len .alternatives) 1}} +Please use the following command to connect to the database: + {{index .alternatives 0 -}}{{else}} +Please use one of the following commands to connect to the database: + {{- range .alternatives}} + {{.}}{{end -}} +{{- end}}`)) ) var ( // dbConnectTemplate is the message printed after a successful "tsh db login" on how to connect. dbConnectTemplate = template.Must(template.New("").Parse(`Connection information for database "{{ .name }}" has been saved. +{{if .connectCommand -}} + You can now connect to it using the following command: {{.connectCommand}} +{{end -}} {{if .configCommand -}} -Or view the connect command for the native database CLI client: + +You can view the connect command for the native database CLI client: {{ .configCommand }} {{end -}} {{if .proxyCommand -}} -Or start a local proxy for database GUI clients: + +You can start a local proxy for database GUI clients: {{ .proxyCommand }} diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go index 985706e15e80c..f7261d17cc073 100644 --- a/tool/tsh/db_test.go +++ b/tool/tsh/db_test.go @@ -130,8 +130,8 @@ func TestDatabaseLogin(t *testing.T) { // Verify certificates. certs, keys, err := decodePEM(profile.DatabaseCertPathForCluster("", test.databaseName)) require.NoError(t, err) - require.Len(t, certs, test.expectCertsLen) - require.Len(t, keys, test.expectKeysLen) + require.Equal(t, test.expectCertsLen, len(certs)) // don't use require.Len, because it spams PEM bytes on fail. + require.Equal(t, test.expectKeysLen, len(keys)) // don't use require.Len, because it spams PEM bytes on fail. }) } diff --git a/tool/tsh/proxy.go b/tool/tsh/proxy.go index 1e11314a2e4e3..0294096492e79 100644 --- a/tool/tsh/proxy.go +++ b/tool/tsh/proxy.go @@ -379,12 +379,14 @@ func onProxyCommandDB(cf *CLIConf) error { if err != nil { return trace.Wrap(err) } - if err := maybeDatabaseLogin(cf, client, profile, routeToDatabase); err != nil { - return trace.Wrap(err) + + // Some protocols require the --tunnel flag, e.g. Snowflake. + if !cf.LocalProxyTunnel && requiresLocalProxyTunnel(routeToDatabase.Protocol) { + return trace.BadParameter(formatDbCmdUnsupportedWithCondition(cf, routeToDatabase, "without the --tunnel flag")) } - if routeToDatabase.Protocol == defaults.ProtocolSnowflake && !cf.LocalProxyTunnel { - return trace.BadParameter("Snowflake proxy works only in the tunnel mode. Please add --tunnel flag to enable it") + if err := maybeDatabaseLogin(cf, client, profile, routeToDatabase); err != nil { + return trace.Wrap(err) } rootCluster, err := client.RootClusterName(cf.Context) @@ -813,6 +815,16 @@ func envVarCommand(format, key, value string) (string, error) { } } +// requiresLocalProxyTunnel returns whether the given protocol requires a local proxy with the --tunnel flag. +func requiresLocalProxyTunnel(protocol string) bool { + switch protocol { + case defaults.ProtocolSnowflake: + return true + default: + return false + } +} + var awsTemplateFuncs = template.FuncMap{ "envVarCommand": envVarCommand, }