From 6267d9daa1f709ba84248aa8e25004c8f55024da Mon Sep 17 00:00:00 2001 From: Andrew Burke Date: Tue, 21 Nov 2023 10:21:39 -0800 Subject: [PATCH] Add callback flag to tsh login This change adds the --callback flag to tsh login, which overrides the base URL printed when doing an SSO login. --- api/utils/proxy.go | 6 +++--- api/utils/proxy_test.go | 6 +++--- lib/client/api.go | 14 ++++++++----- lib/client/redirect.go | 46 ++++++++++++++++++++++++++++++----------- lib/client/weblogin.go | 19 +++++++++++++++++ lib/utils/utils.go | 7 +++++-- lib/utils/utils_test.go | 1 + tool/tsh/common/tsh.go | 8 +++++++ 8 files changed, 82 insertions(+), 25 deletions(-) diff --git a/api/utils/proxy.go b/api/utils/proxy.go index 2f25b5ad5f906..4d88583984f92 100644 --- a/api/utils/proxy.go +++ b/api/utils/proxy.go @@ -27,7 +27,7 @@ import ( // GetProxyURL gets the HTTP proxy address to use for a given address, if any. func GetProxyURL(dialAddr string) *url.URL { - addrURL, err := parse(dialAddr) + addrURL, err := ParseURL(dialAddr) if err != nil || addrURL == nil { return nil } @@ -52,8 +52,8 @@ func GetProxyURL(dialAddr string) *url.URL { return nil } -// parse parses an absolute URL. Unlike url.Parse, absolute URLs without a scheme are allowed. -func parse(addr string) (*url.URL, error) { +// ParseURL parses an absolute URL. Unlike url.Parse, absolute URLs without a scheme are allowed. +func ParseURL(addr string) (*url.URL, error) { if addr == "" { return nil, nil } diff --git a/api/utils/proxy_test.go b/api/utils/proxy_test.go index 265a8b79cb5ef..bfcded7b49c4a 100644 --- a/api/utils/proxy_test.go +++ b/api/utils/proxy_test.go @@ -399,7 +399,7 @@ func TestParse(t *testing.T) { } for _, tc := range successTests { t.Run(fmt.Sprintf("should parse: %s", tc.name), func(t *testing.T) { - u, err := parse(tc.addr) + u, err := ParseURL(tc.addr) require.NoError(t, err) errMsg := fmt.Sprintf("(%v, %v, %v)", u.Scheme, u.Host, u.Path) require.Equal(t, tc.scheme, u.Scheme, errMsg) @@ -415,13 +415,13 @@ func TestParse(t *testing.T) { } for _, tc := range failTests { t.Run(fmt.Sprintf("should not parse: %s", tc.name), func(t *testing.T) { - u, err := parse(tc.addr) + u, err := ParseURL(tc.addr) require.Error(t, err, u) }) } t.Run("empty addr", func(t *testing.T) { - u, err := parse("") + u, err := ParseURL("") require.NoError(t, err) require.Nil(t, u) }) diff --git a/lib/client/api.go b/lib/client/api.go index a4d53c1f44183..871b267b8293a 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -360,6 +360,9 @@ type Config struct { // BindAddr is an optional host:port to bind to for SSO redirect flows. BindAddr string + // CallbackAddr is the optional base URL to give to the user when performing + // SSO redirect flows. + CallbackAddr string // NoRemoteExec will not execute a remote command after connecting to a host, // will block instead. Useful when port forwarding. Equivalent of -N for OpenSSH. @@ -4202,11 +4205,12 @@ func (tc *TeleportClient) ssoLogin(ctx context.Context, priv *keys.PrivateKey, c // ask the CA (via proxy) to sign our public key: response, err := SSHAgentSSOLogin(ctx, SSHLoginSSO{ - SSHLogin: sshLogin, - ConnectorID: connectorID, - Protocol: protocol, - BindAddr: tc.BindAddr, - Browser: tc.Browser, + SSHLogin: sshLogin, + ConnectorID: connectorID, + Protocol: protocol, + BindAddr: tc.BindAddr, + CallbackAddr: tc.CallbackAddr, + Browser: tc.Browser, }, nil) return response, trace.Wrap(err) } diff --git a/lib/client/redirect.go b/lib/client/redirect.go index dc5c63787782c..b6699e4f7c599 100644 --- a/lib/client/redirect.go +++ b/lib/client/redirect.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/trace" apidefaults "github.com/gravitational/teleport/api/defaults" + apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/secret" @@ -85,6 +86,9 @@ type Redirector struct { cancel context.CancelFunc // RedirectorConfig allows customization of Redirector RedirectorConfig + // callbackAddr is the alternate URL to give to the user during login, + // if present. + callbackAddr string } // RedirectorConfig allows customization of Redirector @@ -107,18 +111,29 @@ func NewRedirector(ctx context.Context, login SSHLoginSSO, config *RedirectorCon return nil, trace.Wrap(err) } + var callbackAddr string + if login.CallbackAddr != "" { + callbackURL, err := apiutils.ParseURL(login.CallbackAddr) + if err != nil { + return nil, trace.Wrap(err) + } + callbackURL.Scheme = "https" + callbackAddr = callbackURL.String() + } + ctxCancel, cancel := context.WithCancel(ctx) rd := &Redirector{ - context: ctxCancel, - cancel: cancel, - proxyClient: clt, - proxyURL: proxyURL, - SSHLoginSSO: login, - mux: http.NewServeMux(), - key: key, - shortPath: "/" + uuid.New().String(), - responseC: make(chan *auth.SSHLoginResponse, 1), - errorC: make(chan error, 1), + context: ctxCancel, + cancel: cancel, + proxyClient: clt, + proxyURL: proxyURL, + SSHLoginSSO: login, + mux: http.NewServeMux(), + key: key, + shortPath: "/" + uuid.New().String(), + responseC: make(chan *auth.SSHLoginResponse, 1), + errorC: make(chan error, 1), + callbackAddr: callbackAddr, } if config != nil { @@ -167,7 +182,7 @@ func (rd *Redirector) Start() error { log.Infof("Waiting for response at: %v.", rd.server.URL) // communicate callback redirect URL to the Teleport Proxy - u, err := url.Parse(rd.server.URL + "/callback") + u, err := url.Parse(rd.baseURL() + "/callback") if err != nil { return trace.Wrap(err) } @@ -226,7 +241,14 @@ func (rd *Redirector) ClickableURL() string { if rd.server == nil { return "" } - return utils.ClickableURL(rd.server.URL + rd.shortPath) + return utils.ClickableURL(rd.baseURL() + rd.shortPath) +} + +func (rd *Redirector) baseURL() string { + if rd.callbackAddr != "" { + return rd.callbackAddr + } + return rd.server.URL } // ResponseC returns a channel with response diff --git a/lib/client/weblogin.go b/lib/client/weblogin.go index a29bb41171faf..c08c3a6399997 100644 --- a/lib/client/weblogin.go +++ b/lib/client/weblogin.go @@ -48,6 +48,7 @@ import ( "github.com/gravitational/teleport/api/mfa" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" + "github.com/gravitational/teleport/api/utils/prompt" "github.com/gravitational/teleport/lib/auth" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes" @@ -55,6 +56,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/httplib/csrf" + "github.com/gravitational/teleport/lib/utils" websession "github.com/gravitational/teleport/lib/web/session" ) @@ -248,6 +250,9 @@ type SSHLoginSSO struct { // BindAddr is an optional host:port address to bind // to for SSO login flows BindAddr string + // CallbackAddr is the optional base URL to give to the user when performing + // SSO redirect flows. + CallbackAddr string // Browser can be used to pass the name of a browser to override the system // default (not currently implemented), or set to 'none' to suppress // browser opening entirely. @@ -378,6 +383,20 @@ func initClient(proxyAddr string, insecure bool, pool *x509.CertPool, extraHeade // SSHAgentSSOLogin is used by tsh to fetch user credentials using OpenID Connect (OIDC) or SAML. func SSHAgentSSOLogin(ctx context.Context, login SSHLoginSSO, config *RedirectorConfig) (*auth.SSHLoginResponse, error) { + if login.CallbackAddr != "" && !utils.AsBool(os.Getenv("TELEPORT_LOGIN_SKIP_REMOTE_HOST_WARNING")) { + const callbackPrompt = "Logging in from a remote host means that credentials will be stored on " + + "the remote host. Make sure that you trust the provided callback host " + + "(%v) and that it resolves to the provided bind addr (%v). Continue?" + ok, err := prompt.Confirmation(ctx, os.Stderr, prompt.NewContextReader(os.Stdin), + fmt.Sprintf(callbackPrompt, login.CallbackAddr, login.BindAddr), + ) + if err != nil { + return nil, trace.Wrap(err) + } + if !ok { + return nil, trace.BadParameter("Login canceled.") + } + } rd, err := NewRedirector(ctx, login, config) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/utils/utils.go b/lib/utils/utils.go index 7974fc5105174..042a126fd3262 100644 --- a/lib/utils/utils.go +++ b/lib/utils/utils.go @@ -176,11 +176,14 @@ func ClickableURL(in string) string { return in } ip := net.ParseIP(host) - // if address is not an IP, unspecified, e.g. all interfaces 0.0.0.0 or multicast, + // If address is not an IP address, return it unchanged. + if ip == nil && out.Host != "" { + return out.String() + } + // if address is unspecified, e.g. all interfaces 0.0.0.0 or multicast, // replace with localhost that is clickable if len(ip) == 0 || ip.IsUnspecified() || ip.IsMulticast() { out.Host = fmt.Sprintf("127.0.0.1:%v", port) - return out.String() } return out.String() } diff --git a/lib/utils/utils_test.go b/lib/utils/utils_test.go index 7f5b444a1aa58..75fe9a0f006d4 100644 --- a/lib/utils/utils_test.go +++ b/lib/utils/utils_test.go @@ -181,6 +181,7 @@ func TestClickableURL(t *testing.T) { {info: "unspecified IPV4", in: "http://0.0.0.0:5050/howdy", out: "http://127.0.0.1:5050/howdy"}, {info: "specified IPV4", in: "http://192.168.1.1:5050/howdy", out: "http://192.168.1.1:5050/howdy"}, {info: "specified IPV6", in: "http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:5050/howdy", out: "http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:5050/howdy"}, + {info: "hostname", in: "http://example.com:3000/howdy", out: "http://example.com:3000/howdy"}, } for _, testCase := range testCases { t.Run(testCase.info, func(t *testing.T) { diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index aa30b82849250..51559f244b5cd 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -257,6 +257,9 @@ type CLIConf struct { // BindAddr is an address in the form of host:port to bind to // during `tsh login` command BindAddr string + // CallbackAddr is the optional base URL to give to the user when performing + // SSO redirect flows. + CallbackAddr string // AuthConnector is the name of the connector to use. AuthConnector string @@ -698,6 +701,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { Default("true"). BoolVar(&cf.EnableEscapeSequences) app.Flag("bind-addr", "Override host:port used when opening a browser for cluster logins").Envar(bindAddrEnvVar).StringVar(&cf.BindAddr) + app.Flag("callback", "Override the base URL (host:port) of the link shown when opening a browser for cluster logins. Must be used with --bind-addr.").StringVar(&cf.CallbackAddr) app.Flag("browser-login", browserHelp).Hidden().Envar(browserEnvVar).StringVar(&cf.Browser) modes := []string{mfaModeAuto, mfaModeCrossPlatform, mfaModePlatform, mfaModeOTP} app.Flag("mfa-mode", fmt.Sprintf("Preferred mode for MFA and Passwordless assertions (%v)", strings.Join(modes, ", "))). @@ -3720,6 +3724,10 @@ func loadClientConfigFromCLIConf(cf *CLIConf, proxy string) (*client.Config, err c.HostKeyCallback = client.InsecureSkipHostKeyChecking } c.BindAddr = cf.BindAddr + if cf.CallbackAddr != "" && cf.BindAddr == "" { + return nil, trace.BadParameter("--callback must be used with --bind-addr") + } + c.CallbackAddr = cf.CallbackAddr // Don't execute remote command, used when port forwarding. c.NoRemoteExec = cf.NoRemoteExec