diff --git a/integration/appaccess/pack.go b/integration/appaccess/pack.go index a10ef45bacfad..095392b596c73 100644 --- a/integration/appaccess/pack.go +++ b/integration/appaccess/pack.go @@ -55,6 +55,7 @@ import ( "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web" "github.com/gravitational/teleport/lib/web/app" + websession "github.com/gravitational/teleport/lib/web/session" ) // Pack contains identity as well as initialized Teleport clusters and instances. @@ -243,7 +244,7 @@ func (p *Pack) initWebSession(t *testing.T) { // Extract session cookie and bearer token. require.Len(t, resp.Cookies(), 1) cookie := resp.Cookies()[0] - require.Equal(t, cookie.Name, web.CookieName) + require.Equal(t, cookie.Name, websession.CookieName) p.webCookie = cookie.Value p.webToken = csResp.Token @@ -347,7 +348,7 @@ func (p *Pack) makeWebapiRequest(method, endpoint string, payload []byte) (int, } req.AddCookie(&http.Cookie{ - Name: web.CookieName, + Name: websession.CookieName, Value: p.webCookie, }) req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", p.webToken)) diff --git a/integration/helpers/instance.go b/integration/helpers/instance.go index e7ac22bc04185..1df4164aa70d5 100644 --- a/integration/helpers/instance.go +++ b/integration/helpers/instance.go @@ -64,6 +64,7 @@ import ( "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web" + websession "github.com/gravitational/teleport/lib/web/session" ) const ( @@ -1413,8 +1414,8 @@ func (i *TeleInstance) NewWebClient(cfg ClientConfig) (*WebClient, error) { return nil, trace.BadParameter("unexpected number of cookies returned; got %d, want %d", len(cookies), 1) } cookie := cookies[0] - if cookie.Name != web.CookieName { - return nil, trace.BadParameter("unexpected session cookies returned; got %s, want %s", cookie.Name, web.CookieName) + if cookie.Name != websession.CookieName { + return nil, trace.BadParameter("unexpected session cookies returned; got %s, want %s", cookie.Name, websession.CookieName) } tc, err := i.NewUnauthenticatedClient(cfg) diff --git a/integration/helpers/web.go b/integration/helpers/web.go index e1cfbc85cb4ac..c9392c76a2d83 100644 --- a/integration/helpers/web.go +++ b/integration/helpers/web.go @@ -30,6 +30,7 @@ import ( "github.com/gravitational/teleport/lib/httplib/csrf" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web" + websession "github.com/gravitational/teleport/lib/web/session" "github.com/gravitational/teleport/lib/web/ui" ) @@ -91,7 +92,7 @@ func LoginWebClient(t *testing.T, host, username, password string) *WebClientPac // Extract session cookie and bearer token. require.Len(t, resp.Cookies(), 1) cookie := resp.Cookies()[0] - require.Equal(t, cookie.Name, web.CookieName) + require.Equal(t, cookie.Name, websession.CookieName) webClient := &WebClientPack{ clt: client, @@ -127,7 +128,7 @@ func (w *WebClientPack) DoRequest(t *testing.T, method, endpoint string, payload require.NoError(t, err) req.AddCookie(&http.Cookie{ - Name: web.CookieName, + Name: websession.CookieName, Value: w.webCookie, }) req.Header.Add("Authorization", fmt.Sprintf("Bearer %v", w.bearerToken)) diff --git a/lib/benchmark/web.go b/lib/benchmark/web.go new file mode 100644 index 0000000000000..6be3bfc559891 --- /dev/null +++ b/lib/benchmark/web.go @@ -0,0 +1,236 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package benchmark + +import ( + "context" + "crypto/tls" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/gravitational/roundtrip" + "github.com/gravitational/trace" + + apiclient "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/web" +) + +// WebSSHBenchmark is a benchmark suite that connects to the configured +// target hosts via the web api and executes the provided command. +type WebSSHBenchmark struct { + // Command to execute on the host. + Command []string + // Random whether to connect to a random host or not + Random bool + // Duration of the test used to determine if renewing web sessions + // is necessary. + Duration time.Duration +} + +// BenchBuilder returns a WorkloadFunc for the given benchmark suite. +func (s WebSSHBenchmark) BenchBuilder(ctx context.Context, tc *client.TeleportClient) (WorkloadFunc, error) { + clt, sess, err := tc.LoginWeb(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + webSess := &webSession{ + webSession: sess, + clt: clt, + } + + // The web session will expire before the duration of the test + // so launch the renewal loop. + if !time.Now().Add(s.Duration).Before(webSess.expires()) { + go webSess.renew(ctx) + } + + // Add "exit" to ensure that the session terminates after running the command. + command := strings.Join(append(s.Command, "\r\nexit\r\n"), " ") + + if s.Random { + if tc.Host != "all" { + return nil, trace.BadParameter("random ssh bench commands must use the format @all ") + } + + servers, err := s.getServers(ctx, tc) + if err != nil { + return nil, trace.Wrap(err) + } + + return func(ctx context.Context) error { + return trace.Wrap(s.runCommand(ctx, tc, webSess, chooseRandomHost(servers), command)) + }, nil + } + + return func(ctx context.Context) error { + return trace.Wrap(s.runCommand(ctx, tc, webSess, tc.Host, command)) + }, nil +} + +// runCommand starts a non-interactive SSH session and executes the provided +// command before terminating the session. +func (s WebSSHBenchmark) runCommand(ctx context.Context, tc *client.TeleportClient, webSess *webSession, host, command string) error { + stream, err := s.connectToHost(ctx, tc, webSess, host) + if err != nil { + return trace.Wrap(err) + } + defer stream.Close() + + if _, err := io.WriteString(stream, command); err != nil { + return trace.Wrap(err) + } + + if _, err := io.Copy(tc.Stdout, stream); err != nil && !errors.Is(err, io.EOF) { + return trace.Wrap(err) + } + + return nil +} + +// getServers returns all [types.Server] that the authenticated user has +// access to. +func (s WebSSHBenchmark) getServers(ctx context.Context, tc *client.TeleportClient) ([]types.Server, error) { + clt, err := tc.ConnectToCluster(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + defer clt.Close() + + resources, err := apiclient.GetAllResources[types.Server](ctx, clt.AuthClient, tc.ResourceFilter(types.KindNode)) + if err != nil { + return nil, trace.Wrap(err) + } + + if len(resources) == 0 { + return nil, trace.BadParameter("no target hosts available") + } + + return resources, nil +} + +// connectToHost opens an SSH session to the target host via the Proxy web api. +func (s WebSSHBenchmark) connectToHost(ctx context.Context, tc *client.TeleportClient, webSession *webSession, host string) (*web.TerminalStream, error) { + req := web.TerminalRequest{ + Server: host, + Login: tc.HostLogin, + Term: session.TerminalParams{ + W: 100, + H: 100, + }, + } + + data, err := json.Marshal(req) + if err != nil { + return nil, trace.Wrap(err) + } + + u := url.URL{ + Host: tc.WebProxyAddr, + Scheme: client.WSS, + Path: fmt.Sprintf("/v1/webapi/sites/%v/connect", tc.SiteName), + RawQuery: url.Values{ + "params": []string{string(data)}, + roundtrip.AccessTokenQueryParam: []string{webSession.getToken()}, + }.Encode(), + } + + dialer := websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: tc.InsecureSkipVerify}, + Jar: webSession.getCookieJar(), + } + + ws, resp, err := dialer.DialContext(ctx, u.String(), http.Header{ + "Origin": []string{"http://localhost"}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + defer resp.Body.Close() + + ty, _, err := ws.ReadMessage() + if err != nil { + return nil, trace.Wrap(err) + } + + if ty != websocket.BinaryMessage { + return nil, trace.BadParameter("unexpected websocket message received %d", ty) + } + + stream := web.NewTerminalStream(ctx, ws, utils.NewLogger()) + return stream, trace.Wrap(err) +} + +type webSession struct { + mu sync.Mutex + webSession types.WebSession + clt *client.WebClient +} + +func (s *webSession) renew(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(time.Until(s.expires().Add(-3 * time.Minute))): + resp, err := s.clt.PostJSON(ctx, s.clt.Endpoint("webapi", "sessions", "renew"), nil) + if err != nil { + continue + } + + session, err := client.GetSessionFromResponse(resp) + if err != nil { + continue + } + + s.mu.Lock() + s.webSession = session + s.mu.Unlock() + } + } +} + +func (s *webSession) expires() time.Time { + s.mu.Lock() + defer s.mu.Unlock() + + return s.webSession.GetBearerTokenExpiryTime() +} + +func (s *webSession) getCookieJar() http.CookieJar { + s.mu.Lock() + defer s.mu.Unlock() + + return s.clt.HTTPClient().Jar +} + +func (s *webSession) getToken() string { + s.mu.Lock() + defer s.mu.Unlock() + + return s.webSession.GetBearerToken() +} diff --git a/lib/client/api.go b/lib/client/api.go index d782f66621c5b..c58a56c6c6662 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -3297,6 +3297,35 @@ func (tc *TeleportClient) Login(ctx context.Context) (*Key, error) { return key, nil } +// LoginWeb logs the user in via the Teleport web api the same way that the web UI does. +func (tc *TeleportClient) LoginWeb(ctx context.Context) (*WebClient, types.WebSession, error) { + ctx, span := tc.Tracer.Start( + ctx, + "teleportClient/LoginWeb", + oteltrace.WithSpanKind(oteltrace.SpanKindClient), + ) + defer span.End() + + // Ping the endpoint to see if it's up and find the type of authentication + // supported, also show the message of the day if available. + pr, err := tc.Ping(ctx) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + // Perform the ALPN test once at login. + tc.TLSRoutingConnUpgradeRequired = client.IsALPNConnUpgradeRequired(ctx, tc.WebProxyAddr, tc.InsecureSkipVerify) + + // Get the SSHLoginFunc that matches client and cluster settings. + webLoginFunc, err := tc.getWebLoginFunc(pr) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + clt, session, err := tc.webLogin(ctx, webLoginFunc) + return clt, session, trace.Wrap(err) +} + // AttemptDeviceLogin attempts device authentication for the current device. // It expects to receive the latest activated key, as acquired via // [TeleportClient.Login], and augments the certificates within the key with @@ -3473,6 +3502,139 @@ func (tc *TeleportClient) getSSHLoginFunc(pr *webclient.PingResponse) (SSHLoginF } } +// getWebLoginFunc returns an WebLoginFunc that matches client and cluster settings. +func (tc *TeleportClient) getWebLoginFunc(pr *webclient.PingResponse) (WebLoginFunc, error) { + switch pr.Auth.Type { + case constants.Local: + switch pr.Auth.Local.Name { + case constants.PasswordlessConnector: + // Sanity check settings. + if !pr.Auth.AllowPasswordless { + return nil, trace.BadParameter("passwordless disallowed by cluster settings") + } + return tc.pwdlessLoginWeb, nil + case constants.HeadlessConnector: + return nil, trace.BadParameter("headless logins not allowed for web sessions") + case constants.LocalConnector, "": + // if passwordless is enabled and there are passwordless credentials + // registered, we can try to go with passwordless login even though + // auth=local was selected. + if tc.canDefaultToPasswordless(pr) { + log.Debug("Trying passwordless login because credentials were found") + return tc.pwdlessLoginWeb, nil + } + + return func(ctx context.Context, priv *keys.PrivateKey) (*WebClient, types.WebSession, error) { + return tc.localLoginWeb(ctx, priv, pr.Auth.SecondFactor) + }, nil + default: + return nil, trace.BadParameter("unsupported authentication connector type: %q", pr.Auth.Local.Name) + } + case constants.OIDC: + return nil, trace.NotImplemented("SSO login not supported") + case constants.SAML: + return nil, trace.NotImplemented("SSO login not supported") + case constants.Github: + return nil, trace.NotImplemented("SSO login not supported") + default: + return nil, trace.BadParameter("unsupported authentication type: %q", pr.Auth.Type) + } +} + +// pwdlessLoginWeb performs a passwordless ceremony and then makes a request to authenticate via the web api. +func (tc *TeleportClient) pwdlessLoginWeb(ctx context.Context, priv *keys.PrivateKey) (*WebClient, types.WebSession, error) { + // Only pass on the user if explicitly set, otherwise let the credential + // picker kick in. + user := "" + if tc.ExplicitUsername { + user = tc.Username + } + + sshLogin, err := tc.newSSHLogin(priv) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + clt, session, err := SSHAgentPasswordlessLoginWeb(ctx, SSHLoginPasswordless{ + SSHLogin: sshLogin, + User: user, + AuthenticatorAttachment: tc.AuthenticatorAttachment, + StderrOverride: tc.Stderr, + }) + return clt, session, trace.Wrap(err) +} + +// localLoginWeb performs the mfa ceremony and then makes a request to authenticate via the web api. +func (tc *TeleportClient) localLoginWeb(ctx context.Context, priv *keys.PrivateKey, secondFactor constants.SecondFactorType) (*WebClient, types.WebSession, error) { + // TODO(awly): mfa: ideally, clients should always go through mfaLocalLogin + // (with a nop MFA challenge if no 2nd factor is required). That way we can + // deprecate the direct login endpoint. + switch secondFactor { + case constants.SecondFactorOff, constants.SecondFactorOTP: + clt, session, err := tc.directLoginWeb(ctx, secondFactor, priv) + return clt, session, trace.Wrap(err) + case constants.SecondFactorU2F, constants.SecondFactorWebauthn, constants.SecondFactorOn, constants.SecondFactorOptional: + clt, session, err := tc.mfaLocalLoginWeb(ctx, priv) + return clt, session, trace.Wrap(err) + default: + return nil, nil, trace.BadParameter("unsupported second factor type: %q", secondFactor) + } +} + +// directLoginWeb asks for a password + OTP token then makes a request to authenticate via the web api. +func (tc *TeleportClient) directLoginWeb(ctx context.Context, secondFactorType constants.SecondFactorType, priv *keys.PrivateKey) (*WebClient, types.WebSession, error) { + password, err := tc.AskPassword(ctx) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + // Only ask for a second factor if it's enabled. + var otpToken string + if secondFactorType == constants.SecondFactorOTP { + otpToken, err = tc.AskOTP(ctx) + if err != nil { + return nil, nil, trace.Wrap(err) + } + } + + sshLogin, err := tc.newSSHLogin(priv) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + // authenticate via the web api + clt, session, err := SSHAgentLoginWeb(ctx, SSHLoginDirect{ + SSHLogin: sshLogin, + User: tc.Username, + Password: password, + OTPToken: otpToken, + }) + return clt, session, trace.Wrap(err) +} + +// mfaLocalLoginWeb asks for a password and performs the challenge-response authentication with the web api +func (tc *TeleportClient) mfaLocalLoginWeb(ctx context.Context, priv *keys.PrivateKey) (*WebClient, types.WebSession, error) { + password, err := tc.AskPassword(ctx) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + sshLogin, err := tc.newSSHLogin(priv) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + clt, session, err := SSHAgentMFAWebSessionLogin(ctx, SSHLoginMFA{ + SSHLogin: sshLogin, + User: tc.Username, + Password: password, + AuthenticatorAttachment: tc.AuthenticatorAttachment, + PreferOTP: tc.PreferOTP, + AllowStdinHijack: tc.AllowStdinHijack, + }) + return clt, session, trace.Wrap(err) +} + // hasTouchIDCredentials provides indirection for tests. var hasTouchIDCredentials = touchid.HasCredentials @@ -3566,6 +3728,38 @@ func (tc *TeleportClient) SSHLogin(ctx context.Context, sshLoginFunc SSHLoginFun return key, nil } +// WebLoginFunc is a function which carries out authn with the web server and returns a web session and cookies. +type WebLoginFunc func(context.Context, *keys.PrivateKey) (*WebClient, types.WebSession, error) + +// webLogin uses the given login function to log the client in via the web api. +func (tc *TeleportClient) webLogin(ctx context.Context, webLoginFunc WebLoginFunc) (*WebClient, types.WebSession, error) { + priv, err := tc.GetNewLoginKey(ctx) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + clt, session, err := webLoginFunc(ctx, priv) + if err != nil { + // check if the error is a private key policy error, and relogin if it is. + if privateKeyPolicy, parseErr := keys.ParsePrivateKeyPolicyError(err); parseErr == nil { + // The current private key was rejected due to an unmet key policy requirement. + fmt.Fprintf(tc.Stderr, "Unmet private key policy %q.\n", privateKeyPolicy) + + // Set the private key policy to the expected value and re-login. + tc.PrivateKeyPolicy = privateKeyPolicy + priv, err = tc.GetNewLoginKey(ctx) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + fmt.Fprintf(tc.Stderr, "Re-initiating login with YubiKey generated private key.\n") + clt, session, err = webLoginFunc(ctx, priv) + } + } + + return clt, session, trace.Wrap(err) +} + // GetNewLoginKey gets a new private key for login. func (tc *TeleportClient) GetNewLoginKey(ctx context.Context) (priv *keys.PrivateKey, err error) { switch tc.PrivateKeyPolicy { diff --git a/lib/client/weblogin.go b/lib/client/weblogin.go index 6ca06870ef304..74831580bdb7b 100644 --- a/lib/client/weblogin.go +++ b/lib/client/weblogin.go @@ -19,11 +19,15 @@ package client import ( "bytes" "context" + "crypto/rand" "crypto/x509" + "encoding/hex" "encoding/json" "fmt" "io" "net" + "net/http" + "net/http/cookiejar" "net/url" "os" "os/exec" @@ -45,6 +49,9 @@ import ( wanlib "github.com/gravitational/teleport/lib/auth/webauthn" wancli "github.com/gravitational/teleport/lib/auth/webauthncli" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/httplib" + "github.com/gravitational/teleport/lib/httplib/csrf" + websession "github.com/gravitational/teleport/lib/web/session" ) const ( @@ -301,7 +308,7 @@ type SSHLoginHeadless struct { } // initClient creates a new client to the HTTPS web proxy. -func initClient(proxyAddr string, insecure bool, pool *x509.CertPool, extraHeaders map[string]string) (*WebClient, *url.URL, error) { +func initClient(proxyAddr string, insecure bool, pool *x509.CertPool, extraHeaders map[string]string, opts ...roundtrip.ClientParam) (*WebClient, *url.URL, error) { log := logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentClient, }) @@ -326,8 +333,16 @@ func initClient(proxyAddr string, insecure bool, pool *x509.CertPool, extraHeade fmt.Fprintf(os.Stderr, "WARNING: You are using insecure connection to Teleport proxy %v\n", proxyAddr) } - opt := roundtrip.HTTPClient(newClient(insecure, pool, extraHeaders)) - clt, err := NewWebClient(proxyAddr, opt) + jar, err := cookiejar.New(nil) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + opts = append(opts, + roundtrip.HTTPClient(newClient(insecure, pool, extraHeaders)), + roundtrip.CookieJar(jar), + ) + clt, err := NewWebClient(proxyAddr, opts...) if err != nil { return nil, nil, trace.Wrap(err) } @@ -658,3 +673,248 @@ func GetWebConfig(ctx context.Context, proxyAddr string, insecure bool) (*webcli return &cfg, nil } + +// CreateWebSessionReq is a request for the web api to +// initiate a new web session. +type CreateWebSessionReq struct { + // User is the Teleport username. + User string `json:"user"` + // Pass is the password. + Pass string `json:"pass"` + // SecondFactorToken is the OTP. + SecondFactorToken string `json:"second_factor_token"` +} + +// CreateWebSessionResponse is a response from the web api +// to a [CreateWebSessionReq] request. +type CreateWebSessionResponse struct { + // TokenType is token type (bearer) + TokenType string `json:"type"` + // Token value + Token string `json:"token"` + // TokenExpiresIn sets seconds before this token is not valid + TokenExpiresIn int `json:"expires_in"` + // SessionExpires is when this session expires. + SessionExpires time.Time `json:"sessionExpires,omitempty"` + // SessionInactiveTimeoutMS specifies how long in milliseconds + // a user WebUI session can be left idle before being logged out + // by the server. A zero value means there is no idle timeout set. + SessionInactiveTimeoutMS int `json:"sessionInactiveTimeout"` +} + +// SSHAgentLoginWeb is used by tsh to fetch local user credentials via the web api. +func SSHAgentLoginWeb(ctx context.Context, login SSHLoginDirect) (*WebClient, types.WebSession, error) { + clt, _, err := initClient(login.ProxyAddr, login.Insecure, login.Pool, login.ExtraHeaders) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + token := make([]byte, 32) + if _, err := rand.Read(token); err != nil { + return nil, nil, trace.Wrap(err) + } + + csrfToken := hex.EncodeToString(token) + resp, err := httplib.ConvertResponse(clt.RoundTrip(func() (*http.Response, error) { + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(&CreateWebSessionReq{ + User: login.User, + Pass: login.Password, + SecondFactorToken: login.OTPToken, + }); err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "POST", clt.Endpoint("webapi", "sessions", "web"), &buf) + if err != nil { + return nil, err + } + + cookie := &http.Cookie{ + Name: csrf.CookieName, + Value: csrfToken, + } + + req.AddCookie(cookie) + + req.Header.Set("Content-Type", "application/json") + req.Header.Set(csrf.HeaderName, csrfToken) + return clt.HTTPClient().Do(req) + })) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + session, err := GetSessionFromResponse(resp) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + return clt, session, nil +} + +// SSHAgentMFAWebSessionLogin requests a MFA challenge via the proxy web api. +// If the credentials are valid, the proxy will return a challenge. We then +// prompt the user to provide 2nd factor and pass the response to the proxy. +func SSHAgentMFAWebSessionLogin(ctx context.Context, login SSHLoginMFA) (*WebClient, types.WebSession, error) { + clt, _, err := initClient(login.ProxyAddr, login.Insecure, login.Pool, login.ExtraHeaders) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + beginReq := MFAChallengeRequest{ + User: login.User, + Pass: login.Password, + } + challengeJSON, err := clt.PostJSON(ctx, clt.Endpoint("webapi", "mfa", "login", "begin"), beginReq) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + challenge := &MFAAuthenticateChallenge{} + if err := json.Unmarshal(challengeJSON.Bytes(), challenge); err != nil { + return nil, nil, trace.Wrap(err) + } + + // Convert to auth gRPC proto challenge. + challengePB := &proto.MFAAuthenticateChallenge{} + if challenge.TOTPChallenge { + challengePB.TOTP = &proto.TOTPChallenge{} + } + if challenge.WebauthnChallenge != nil { + challengePB.WebauthnChallenge = wanlib.CredentialAssertionToProto(challenge.WebauthnChallenge) + } + + respPB, err := PromptMFAChallenge(ctx, challengePB, login.ProxyAddr, &PromptMFAChallengeOpts{ + AllowStdinHijack: login.AllowStdinHijack, + AuthenticatorAttachment: login.AuthenticatorAttachment, + PreferOTP: login.PreferOTP, + }) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + challengeResp := AuthenticateWebUserRequest{ + User: login.User, + } + // Convert back from auth gRPC proto response. + switch r := respPB.Response.(type) { + case *proto.MFAAuthenticateResponse_Webauthn: + challengeResp.WebauthnAssertionResponse = wanlib.CredentialAssertionResponseFromProto(r.Webauthn) + default: + // No challenge was sent, so we send back just username/password. + } + + loginRespJSON, err := clt.PostJSON(ctx, clt.Endpoint("webapi", "mfa", "login", "finishsession"), challengeResp) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + session, err := GetSessionFromResponse(loginRespJSON) + if err != nil { + return nil, nil, trace.Wrap(err) + } + return clt, session, nil +} + +// SSHAgentPasswordlessLoginWeb requests a passwordless MFA challenge via the proxy +// web api. +func SSHAgentPasswordlessLoginWeb(ctx context.Context, login SSHLoginPasswordless) (*WebClient, types.WebSession, error) { + webClient, webURL, err := initClient(login.ProxyAddr, login.Insecure, login.Pool, login.ExtraHeaders) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + challengeJSON, err := webClient.PostJSON( + ctx, webClient.Endpoint("webapi", "mfa", "login", "begin"), + &MFAChallengeRequest{ + Passwordless: true, + }) + if err != nil { + return nil, nil, trace.Wrap(err) + } + challenge := &MFAAuthenticateChallenge{} + if err := json.Unmarshal(challengeJSON.Bytes(), challenge); err != nil { + return nil, nil, trace.Wrap(err) + } + // Sanity check WebAuthn challenge. + switch { + case challenge.WebauthnChallenge == nil: + return nil, nil, trace.BadParameter("passwordless: webauthn challenge missing") + case challenge.WebauthnChallenge.Response.UserVerification == protocol.VerificationDiscouraged: + return nil, nil, trace.BadParameter("passwordless: user verification requirement too lax (%v)", challenge.WebauthnChallenge.Response.UserVerification) + } + + stderr := login.StderrOverride + if stderr == nil { + stderr = os.Stderr + } + + prompt := login.CustomPrompt + if prompt == nil { + prompt = wancli.NewDefaultPrompt(ctx, stderr) + } + + mfaResp, _, err := promptWebauthn(ctx, webURL.String(), challenge.WebauthnChallenge, prompt, &wancli.LoginOpts{ + User: login.User, + AuthenticatorAttachment: login.AuthenticatorAttachment, + }) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + loginRespJSON, err := webClient.PostJSON( + ctx, webClient.Endpoint("webapi", "mfa", "login", "finishsession"), + &AuthenticateWebUserRequest{ + User: login.User, + WebauthnAssertionResponse: wanlib.CredentialAssertionResponseFromProto(mfaResp.GetWebauthn()), + }) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + webSession, err := GetSessionFromResponse(loginRespJSON) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + return webClient, webSession, nil +} + +// GetSessionFromResponse creates a [types.WebSession] if a cookie +// named [websession.CookieName] is present in the provided [roundtrip.Response]. +func GetSessionFromResponse(resp *roundtrip.Response) (types.WebSession, error) { + var sess CreateWebSessionResponse + if err := json.Unmarshal(resp.Bytes(), &sess); err != nil { + return nil, trace.Wrap(err) + } + + cookies := resp.Cookies() + + var sessionCookie *http.Cookie + for _, cookie := range cookies { + if cookie.Name == websession.CookieName { + sessionCookie = cookie + break + } + } + + if sessionCookie == nil { + return nil, trace.BadParameter("no session cookie present") + } + + cookie, err := websession.DecodeCookie(sessionCookie.Value) + if err != nil { + return nil, trace.Wrap(err) + } + + session, err := types.NewWebSession(cookie.SID, types.KindWebSession, types.WebSessionSpecV2{ + User: cookie.User, + BearerToken: sess.Token, + BearerTokenExpires: time.Now().Add(time.Duration(sess.TokenExpiresIn) * time.Second), + Expires: sess.SessionExpires, + LoginTime: time.Now(), + IdleTimeout: types.Duration(time.Duration(sess.SessionInactiveTimeoutMS) * time.Millisecond), + }) + return session, trace.Wrap(err) +} diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 144bc865eee20..2bb1f0d4d70ea 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -87,6 +87,7 @@ import ( "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/app" + websession "github.com/gravitational/teleport/lib/web/session" "github.com/gravitational/teleport/lib/web/ui" ) @@ -1956,7 +1957,7 @@ func (h *Handler) createWebSession(w http.ResponseWriter, r *http.Request, p htt return nil, trace.Wrap(err) } - if err := SetSessionCookie(w, req.User, webSession.GetName()); err != nil { + if err := websession.SetCookie(w, req.User, webSession.GetName()); err != nil { return nil, trace.Wrap(err) } @@ -2004,7 +2005,7 @@ func (h *Handler) logout(ctx context.Context, w http.ResponseWriter, sctx *Sessi if err := sctx.Invalidate(ctx); err != nil { return trace.Wrap(err) } - ClearSession(w) + websession.ClearCookie(w) return nil } @@ -2044,7 +2045,7 @@ func (h *Handler) renewWebSession(w http.ResponseWriter, r *http.Request, params if err != nil { return nil, trace.Wrap(err) } - if err := SetSessionCookie(w, newSession.GetUser(), newSession.GetName()); err != nil { + if err := websession.SetCookie(w, newSession.GetUser(), newSession.GetName()); err != nil { return nil, trace.Wrap(err) } @@ -2131,7 +2132,7 @@ func (h *Handler) changeUserAuthentication(w http.ResponseWriter, r *http.Reques h.log.WithError(err).Error("Failed to set passwordless as connector name.") } - if err := SetSessionCookie(w, sess.GetUser(), sess.GetName()); err != nil { + if err := websession.SetCookie(w, sess.GetUser(), sess.GetName()); err != nil { return nil, trace.Wrap(err) } @@ -2352,7 +2353,7 @@ func (h *Handler) mfaLoginFinishSession(w http.ResponseWriter, r *http.Request, // Fetch user from session, user is empty for passwordless requests. user := session.GetUser() - if err := SetSessionCookie(w, user, session.GetName()); err != nil { + if err := websession.SetCookie(w, user, session.GetName()); err != nil { return nil, trace.Wrap(err) } @@ -3823,17 +3824,17 @@ func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error { // and bearer token func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { const missingCookieMsg = "missing session cookie" - cookie, err := r.Cookie(CookieName) + cookie, err := r.Cookie(websession.CookieName) if err != nil || (cookie != nil && cookie.Value == "") { return nil, trace.AccessDenied(missingCookieMsg) } - decodedCookie, err := DecodeCookie(cookie.Value) + decodedCookie, err := websession.DecodeCookie(cookie.Value) if err != nil { return nil, trace.AccessDenied("failed to decode cookie") } ctx, err := h.auth.getOrCreateSession(r.Context(), decodedCookie.User, decodedCookie.SID) if err != nil { - ClearSession(w) + websession.ClearCookie(w) return nil, trace.AccessDenied("need auth") } if checkBearerToken { @@ -4028,7 +4029,7 @@ func SSOSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, resp } } - if err := SetSessionCookie(w, response.Username, response.SessionName); err != nil { + if err := websession.SetCookie(w, response.Username, response.SessionName); err != nil { return trace.Wrap(err) } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index e0c628deaa096..9d7979e1a2984 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -121,6 +121,7 @@ import ( "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + websession "github.com/gravitational/teleport/lib/web/session" "github.com/gravitational/teleport/lib/web/ui" ) @@ -4867,7 +4868,7 @@ func TestCreateAppSession(t *testing.T) { rawCookie := *pack.cookies[0] cookieBytes, err := hex.DecodeString(rawCookie.Value) require.NoError(t, err) - var sessionCookie SessionCookie + var sessionCookie websession.Cookie err = json.Unmarshal(cookieBytes, &sessionCookie) require.NoError(t, err) @@ -5032,7 +5033,7 @@ func TestCreateAppSessionHealthCheckAppServer(t *testing.T) { rawCookie := *pack.cookies[0] cookieBytes, err := hex.DecodeString(rawCookie.Value) require.NoError(t, err) - var sessionCookie SessionCookie + var sessionCookie websession.Cookie err = json.Unmarshal(cookieBytes, &sessionCookie) require.NoError(t, err) diff --git a/lib/web/cookie.go b/lib/web/session/cookie.go similarity index 65% rename from lib/web/cookie.go rename to lib/web/session/cookie.go index d84d567e65baf..c3fbb24a6e919 100644 --- a/lib/web/cookie.go +++ b/lib/web/session/cookie.go @@ -15,7 +15,7 @@ limitations under the License. */ -package web +package session import ( "encoding/hex" @@ -23,33 +23,39 @@ import ( "net/http" ) -// SessionCookie stores information about active user and session -type SessionCookie struct { +// Cookie stores information about active user and session +type Cookie struct { User string `json:"user"` SID string `json:"sid"` } +// EncodeCookie returns the string representation of a [Cookie] +// that should be used to store the user session in the cookies +// of a [http.ResponseWriter]. func EncodeCookie(user, sid string) (string, error) { - bytes, err := json.Marshal(SessionCookie{User: user, SID: sid}) + bytes, err := json.Marshal(Cookie{User: user, SID: sid}) if err != nil { return "", err } return hex.EncodeToString(bytes), nil } -func DecodeCookie(b string) (*SessionCookie, error) { +// DecodeCookie returns the [Cookie] from the provided string. +func DecodeCookie(b string) (*Cookie, error) { bytes, err := hex.DecodeString(b) if err != nil { return nil, err } - var c *SessionCookie + var c *Cookie if err := json.Unmarshal(bytes, &c); err != nil { return nil, err } return c, nil } -func SetSessionCookie(w http.ResponseWriter, user, sid string) error { +// SetCookie encodes the provided user and session id via [EncodeCookie] +// and then sets the [http.Cookie] of the provided [http.ResponseWriter]. +func SetCookie(w http.ResponseWriter, user, sid string) error { d, err := EncodeCookie(user, sid) if err != nil { return err @@ -65,7 +71,8 @@ func SetSessionCookie(w http.ResponseWriter, user, sid string) error { return nil } -func ClearSession(w http.ResponseWriter) { +// ClearCookie wipes the session cookie to invalidate the user session. +func ClearCookie(w http.ResponseWriter) { http.SetCookie(w, &http.Cookie{ Name: CookieName, Value: "", diff --git a/lib/web/session/cookie_test.go b/lib/web/session/cookie_test.go new file mode 100644 index 0000000000000..1b41d80ded029 --- /dev/null +++ b/lib/web/session/cookie_test.go @@ -0,0 +1,50 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package session + +import ( + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestCookies(t *testing.T) { + const ( + user = "llama" + sessionID = "98765" + ) + expectedCookie := &Cookie{User: user, SID: sessionID} + + encodedCookie, err := EncodeCookie(user, sessionID) + require.NoError(t, err) + + cookie, err := DecodeCookie(encodedCookie) + require.NoError(t, err) + require.Equal(t, expectedCookie, cookie) + + recorder := httptest.NewRecorder() + require.Empty(t, recorder.Header().Get("Set-Cookie")) + + require.NoError(t, SetCookie(recorder, user, sessionID)) + ClearCookie(recorder) + setCookies := recorder.Header().Values("Set-Cookie") + require.Len(t, setCookies, 2) + + // SetCookie will store the encoded session in the cookie + require.Equal(t, "__Host-session=7b2275736572223a226c6c616d61222c22736964223a223938373635227d; Path=/; HttpOnly; Secure", setCookies[0]) + // ClearCookie will add an entry with the cookie value cleared out + require.Equal(t, "__Host-session=; Path=/; HttpOnly; Secure", setCookies[1]) +} diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index 7ff575a5e2fbc..5f5be3bac7014 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -892,7 +892,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { logout := app.Command("logout", "Delete a cluster certificate.") // bench - bench := app.Command("bench", "Run shell or execute a command on a remote SSH node.").Hidden() + bench := app.Command("bench", "Run Teleport benchmark tests.").Hidden() bench.Flag("cluster", clusterHelp).Short('c').StringVar(&cf.SiteName) bench.Flag("duration", "Test duration").Default("1s").DurationVar(&cf.BenchDuration) bench.Flag("rate", "Requests per second rate").Default("10").IntVar(&cf.BenchRate) @@ -901,14 +901,21 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { bench.Flag("ticks", "Ticks per half distance").Default("100").Int32Var(&cf.BenchTicks) bench.Flag("scale", "Value scale in which to scale the recorded values").Default("1.0").Float64Var(&cf.BenchValueScale) - benchSSH := bench.Command("ssh", "Run SSH benchmark test").Hidden() + benchSSH := bench.Command("ssh", "Run SSH benchmark tests").Hidden() benchSSH.Arg("[user@]host", "Remote hostname and the login to use").Required().StringVar(&cf.UserHost) benchSSH.Arg("command", "Command to execute on a remote host").Required().StringsVar(&cf.RemoteCommand) benchSSH.Flag("port", "SSH port on a remote host").Short('p').Int32Var(&cf.NodePort) - benchSSH.Flag("interactive", "Create interactive SSH session").BoolVar(&cf.BenchInteractive) benchSSH.Flag("random", "Connect to random hosts for each SSH session. The provided hostname must be all: tsh bench ssh --random @all ").BoolVar(&cf.BenchRandom) + + benchWeb := bench.Command("web", "Run Web benchmark tests").Hidden() + benchWebSSH := benchWeb.Command("ssh", "Run SSH benchmark tests").Hidden() + benchWebSSH.Arg("[user@]host", "Remote hostname and the login to use").Required().StringVar(&cf.UserHost) + benchWebSSH.Arg("command", "Command to execute on a remote host").Required().StringsVar(&cf.RemoteCommand) + benchWebSSH.Flag("port", "SSH port on a remote host").Short('p').Int32Var(&cf.NodePort) + benchWebSSH.Flag("random", "Connect to random hosts for each SSH session. The provided hostname must be all: tsh bench ssh --random @all ").BoolVar(&cf.BenchRandom) + var benchKubeOpts benchKubeOptions - benchKube := bench.Command("kube", "Run Kube benchmark test").Hidden() + benchKube := bench.Command("kube", "Run Kube benchmark tests").Hidden() benchKube.Flag("kube-namespace", "Selects the ").Default("default").StringVar(&benchKubeOpts.namespace) benchListKube := benchKube.Command("ls", "Run a benchmark test to list Pods").Hidden() benchListKube.Arg("kube_cluster", "Kubernetes cluster to use").Required().StringVar(&cf.KubernetesCluster) @@ -1188,6 +1195,15 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { Random: cf.BenchRandom, }, ) + case benchWebSSH.FullCommand(): + err = onBenchmark( + &cf, + &benchmark.WebSSHBenchmark{ + Command: cf.RemoteCommand, + Random: cf.BenchRandom, + Duration: cf.BenchDuration, + }, + ) case benchListKube.FullCommand(): err = onBenchmark( &cf,