diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 9cfb4a81175ec..2d9c19944d824 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -77,8 +77,8 @@ import ( ) const ( - // ssoLoginConsoleErr is a generic error message to hide revealing sso login failure msgs. - ssoLoginConsoleErr = "Failed to login. Please check Teleport's log for more details." + // SSOLoginConsoleErr is a generic error message to hide revealing sso login failure msgs. + SSOLoginConsoleErr = "Failed to login. Please check Teleport's log for more details." metaRedirectHTML = ` @@ -184,10 +184,10 @@ type Config struct { // Enables web UI if set. StaticFS http.FileSystem - // cachedSessionLingeringThreshold specifies the time the session will linger + // CachedSessionLingeringThreshold specifies the time the session will linger // in the cache before getting purged after it has expired. // Defaults to cachedSessionLingeringThreshold if unspecified. - cachedSessionLingeringThreshold *time.Duration + CachedSessionLingeringThreshold *time.Duration // ClusterFeatures contains flags for supported/unsupported features. ClusterFeatures proto.Features @@ -266,8 +266,8 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { } sessionLingeringThreshold := cachedSessionLingeringThreshold - if cfg.cachedSessionLingeringThreshold != nil { - sessionLingeringThreshold = *cfg.cachedSessionLingeringThreshold + if cfg.CachedSessionLingeringThreshold != nil { + sessionLingeringThreshold = *cfg.CachedSessionLingeringThreshold } auth, err := newSessionCache(sessionCacheOptions{ @@ -1211,17 +1211,17 @@ func (h *Handler) oidcLoginWeb(w http.ResponseWriter, r *http.Request, p httprou logger := h.log.WithField("auth", "oidc") logger.Debug("Web login start.") - req, err := parseSSORequestParams(r) + req, err := ParseSSORequestParams(r) if err != nil { logger.WithError(err).Error("Failed to extract SSO parameters from request.") return client.LoginFailedRedirectURL } response, err := h.cfg.ProxyClient.CreateOIDCAuthRequest(r.Context(), types.OIDCAuthRequest{ - CSRFToken: req.csrfToken, - ConnectorID: req.connectorID, + CSRFToken: req.CSRFToken, + ConnectorID: req.ConnectorID, CreateWebSession: true, - ClientRedirectURL: req.clientRedirectURL, + ClientRedirectURL: req.ClientRedirectURL, CheckUser: true, ProxyAddress: r.Host, }) @@ -1237,17 +1237,17 @@ func (h *Handler) githubLoginWeb(w http.ResponseWriter, r *http.Request, p httpr logger := h.log.WithField("auth", "github") logger.Debug("Web login start.") - req, err := parseSSORequestParams(r) + req, err := ParseSSORequestParams(r) if err != nil { logger.WithError(err).Error("Failed to extract SSO parameters from request.") return client.LoginFailedRedirectURL } response, err := h.cfg.ProxyClient.CreateGithubAuthRequest(r.Context(), types.GithubAuthRequest{ - CSRFToken: req.csrfToken, - ConnectorID: req.connectorID, + CSRFToken: req.CSRFToken, + ConnectorID: req.ConnectorID, CreateWebSession: true, - ClientRedirectURL: req.clientRedirectURL, + ClientRedirectURL: req.ClientRedirectURL, }) if err != nil { logger.WithError(err).Error("Error creating auth request.") @@ -1265,12 +1265,12 @@ func (h *Handler) githubLoginConsole(w http.ResponseWriter, r *http.Request, p h req := new(client.SSOLoginConsoleReq) if err := httplib.ReadJSON(r, req); err != nil { logger.WithError(err).Error("Error reading json.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginConsoleErr) } if err := req.CheckAndSetDefaults(); err != nil { logger.WithError(err).Error("Missing request parameters.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginConsoleErr) } response, err := h.cfg.ProxyClient.CreateGithubAuthRequest(r.Context(), types.GithubAuthRequest{ @@ -1285,7 +1285,7 @@ func (h *Handler) githubLoginConsole(w http.ResponseWriter, r *http.Request, p h }) if err != nil { logger.WithError(err).Error("Failed to create Github auth request.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginConsoleErr) } return &client.SSOLoginConsoleResponse{ @@ -1307,7 +1307,7 @@ func (h *Handler) githubCallback(w http.ResponseWriter, r *http.Request, p httpr // this improves the UX by terminating the failed SSO flow immediately, rather than hoping for a timeout. if requestID := r.URL.Query().Get("state"); requestID != "" { if request, errGet := h.cfg.ProxyClient.GetGithubAuthRequest(r.Context(), requestID); errGet == nil && !request.CreateWebSession { - if redURL, errEnc := redirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { + if redURL, errEnc := RedirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { return redURL.String() } } @@ -1323,19 +1323,19 @@ func (h *Handler) githubCallback(w http.ResponseWriter, r *http.Request, p httpr if response.Req.CreateWebSession { logger.Infof("Redirecting to web browser.") - res := &ssoCallbackResponse{ - csrfToken: response.Req.CSRFToken, - username: response.Username, - sessionName: response.Session.GetName(), - clientRedirectURL: response.Req.ClientRedirectURL, + res := &SSOCallbackResponse{ + CSRFToken: response.Req.CSRFToken, + Username: response.Username, + SessionName: response.Session.GetName(), + ClientRedirectURL: response.Req.ClientRedirectURL, } - if err := ssoSetWebSessionAndRedirectURL(w, r, res, true); err != nil { + if err := SSOSetWebSessionAndRedirectURL(w, r, res, true); err != nil { logger.WithError(err).Error("Error setting web session.") return client.LoginFailedRedirectURL } - return res.clientRedirectURL + return res.ClientRedirectURL } logger.Infof("Callback is redirecting to console login.") @@ -1369,12 +1369,12 @@ func (h *Handler) oidcLoginConsole(w http.ResponseWriter, r *http.Request, p htt req := new(client.SSOLoginConsoleReq) if err := httplib.ReadJSON(r, req); err != nil { logger.WithError(err).Error("Error reading json.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginConsoleErr) } if err := req.CheckAndSetDefaults(); err != nil { logger.WithError(err).Error("Missing request parameters.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginConsoleErr) } response, err := h.cfg.ProxyClient.CreateOIDCAuthRequest(r.Context(), types.OIDCAuthRequest{ @@ -1391,7 +1391,7 @@ func (h *Handler) oidcLoginConsole(w http.ResponseWriter, r *http.Request, p htt }) if err != nil { logger.WithError(err).Error("Failed to create OIDC auth request.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginConsoleErr) } return &client.SSOLoginConsoleResponse{ @@ -1413,7 +1413,7 @@ func (h *Handler) oidcCallback(w http.ResponseWriter, r *http.Request, p httprou // this improves the UX by terminating the failed SSO flow immediately, rather than hoping for a timeout. if requestID := r.URL.Query().Get("state"); requestID != "" { if request, errGet := h.cfg.ProxyClient.GetOIDCAuthRequest(r.Context(), requestID); errGet == nil && !request.CreateWebSession { - if redURL, errEnc := redirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { + if redURL, errEnc := RedirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { return redURL.String() } } @@ -1430,19 +1430,19 @@ func (h *Handler) oidcCallback(w http.ResponseWriter, r *http.Request, p httprou if response.Req.CreateWebSession { logger.Info("Redirecting to web browser.") - res := &ssoCallbackResponse{ - csrfToken: response.Req.CSRFToken, - username: response.Username, - sessionName: response.Session.GetName(), - clientRedirectURL: response.Req.ClientRedirectURL, + res := &SSOCallbackResponse{ + CSRFToken: response.Req.CSRFToken, + Username: response.Username, + SessionName: response.Session.GetName(), + ClientRedirectURL: response.Req.ClientRedirectURL, } - if err := ssoSetWebSessionAndRedirectURL(w, r, res, true); err != nil { + if err := SSOSetWebSessionAndRedirectURL(w, r, res, true); err != nil { logger.WithError(err).Error("Error setting web session.") return client.LoginFailedRedirectURL } - return res.clientRedirectURL + return res.ClientRedirectURL } logger.Info("Callback redirecting to console login.") @@ -1593,7 +1593,7 @@ func ConstructSSHResponse(response AuthParams) (*url.URL, error) { return u, nil } -func redirectURLWithError(clientRedirectURL string, errReply error) (*url.URL, error) { +func RedirectURLWithError(clientRedirectURL string, errReply error) (*url.URL, error) { u, err := url.Parse(clientRedirectURL) if err != nil { return nil, trace.Wrap(err) @@ -3167,13 +3167,13 @@ func makeTeleportClientConfig(ctx context.Context, sesCtx *SessionContext) (*cli return config, nil } -type ssoRequestParams struct { - clientRedirectURL string - connectorID string - csrfToken string +type SSORequestParams struct { + ClientRedirectURL string + ConnectorID string + CSRFToken string } -func parseSSORequestParams(r *http.Request) (*ssoRequestParams, error) { +func ParseSSORequestParams(r *http.Request) (*SSORequestParams, error) { // Manually grab the value from query param "redirect_url". // // The "redirect_url" param can contain its own query params such as in @@ -3205,37 +3205,37 @@ func parseSSORequestParams(r *http.Request) (*ssoRequestParams, error) { return nil, trace.Wrap(err) } - return &ssoRequestParams{ - clientRedirectURL: clientRedirectURL, - connectorID: connectorID, - csrfToken: csrfToken, + return &SSORequestParams{ + ClientRedirectURL: clientRedirectURL, + ConnectorID: connectorID, + CSRFToken: csrfToken, }, nil } -type ssoCallbackResponse struct { - csrfToken string - username string - sessionName string - clientRedirectURL string +type SSOCallbackResponse struct { + CSRFToken string + Username string + SessionName string + ClientRedirectURL string } -func ssoSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, response *ssoCallbackResponse, verifyCSRF bool) error { +func SSOSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, response *SSOCallbackResponse, verifyCSRF bool) error { if verifyCSRF { - if err := csrf.VerifyToken(response.csrfToken, r); err != nil { + if err := csrf.VerifyToken(response.CSRFToken, r); err != nil { return trace.Wrap(err) } } - if err := SetSessionCookie(w, response.username, response.sessionName); err != nil { + if err := SetSessionCookie(w, response.Username, response.SessionName); err != nil { return trace.Wrap(err) } - parsedURL, err := url.Parse(response.clientRedirectURL) + parsedURL, err := url.Parse(response.ClientRedirectURL) if err != nil { return trace.Wrap(err) } - response.clientRedirectURL = parsedURL.RequestURI() + response.ClientRedirectURL = parsedURL.RequestURI() return nil } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 3e14d583c25d3..6c4e5aec8eda2 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -63,7 +63,6 @@ import ( "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/auth/mocku2f" "github.com/gravitational/teleport/lib/auth/native" "github.com/gravitational/teleport/lib/auth/testauthority" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" @@ -114,26 +113,6 @@ import ( clientcmdapi "k8s.io/client-go/tools/clientcmd/api" ) -const hostID = "00000000-0000-0000-0000-000000000000" - -type WebSuite struct { - ctx context.Context - cancel context.CancelFunc - - node *regular.Server - proxy *regular.Server - proxyTunnel reversetunnel.Server - srvID string - - user string - webServer *httptest.Server - - mockU2F *mocku2f.Key - server *auth.TestServer - proxyClient *auth.Client - clock clockwork.FakeClock -} - // TestMain will re-execute Teleport to run a command if "exec" is passed to // it as an argument. Otherwise it will run tests as normal. func TestMain(m *testing.M) { @@ -150,381 +129,6 @@ func TestMain(m *testing.M) { os.Exit(code) } -func newWebSuite(t *testing.T) *WebSuite { - mockU2F, err := mocku2f.Create() - require.NoError(t, err) - require.NotNil(t, mockU2F) - - u, err := user.Current() - require.NoError(t, err) - - ctx, cancel := context.WithCancel(context.Background()) - s := &WebSuite{ - mockU2F: mockU2F, - clock: clockwork.NewFakeClock(), - user: u.Username, - ctx: ctx, - cancel: cancel, - } - - networkingConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{ - KeepAliveInterval: types.Duration(10 * time.Second), - }) - require.NoError(t, err) - - s.server, err = auth.NewTestServer(auth.TestServerConfig{ - Auth: auth.TestAuthServerConfig{ - ClusterName: "localhost", - Dir: t.TempDir(), - Clock: s.clock, - ClusterNetworkingConfig: networkingConfig, - }, - }) - require.NoError(t, err) - - // Register the auth server, since test auth server doesn't start its own - // heartbeat. - err = s.server.Auth().UpsertAuthServer(&types.ServerV2{ - Kind: types.KindAuthServer, - Version: types.V2, - Metadata: types.Metadata{ - Namespace: apidefaults.Namespace, - Name: "auth", - }, - Spec: types.ServerSpecV2{ - Addr: s.server.TLS.Listener.Addr().String(), - Hostname: "localhost", - Version: teleport.Version, - }, - }) - require.NoError(t, err) - - priv, pub, err := testauthority.New().GenerateKeyPair() - require.NoError(t, err) - - tlsPub, err := auth.PrivateKeyToPublicKeyTLS(priv) - require.NoError(t, err) - - // start node - certs, err := s.server.Auth().GenerateHostCerts(s.ctx, - &authproto.HostCertsRequest{ - HostID: hostID, - NodeName: s.server.ClusterName(), - Role: types.RoleNode, - PublicSSHKey: pub, - PublicTLSKey: tlsPub, - }) - require.NoError(t, err) - - signer, err := sshutils.NewSigner(priv, certs.SSH) - require.NoError(t, err) - - nodeID := "node" - nodeClient, err := s.server.NewClient(auth.TestIdentity{ - I: auth.BuiltinRole{ - Role: types.RoleNode, - Username: nodeID, - }, - }) - require.NoError(t, err) - - nodeLockWatcher, err := services.NewLockWatcher(s.ctx, services.LockWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentNode, - Client: nodeClient, - }, - }) - require.NoError(t, err) - - // create SSH service: - nodeDataDir := t.TempDir() - node, err := regular.New( - utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, - s.server.ClusterName(), - []ssh.Signer{signer}, - nodeClient, - nodeDataDir, - "", - utils.NetAddr{}, - nodeClient, - regular.SetUUID(nodeID), - regular.SetNamespace(apidefaults.Namespace), - regular.SetShell("/bin/sh"), - regular.SetEmitter(nodeClient), - regular.SetPAMConfig(&pam.Config{Enabled: false}), - regular.SetBPF(&bpf.NOP{}), - regular.SetRestrictedSessionManager(&restricted.NOP{}), - regular.SetClock(s.clock), - regular.SetLockWatcher(nodeLockWatcher), - ) - require.NoError(t, err) - s.node = node - s.srvID = node.ID() - require.NoError(t, s.node.Start()) - require.NoError(t, auth.CreateUploaderDir(nodeDataDir)) - - // create reverse tunnel service: - proxyID := "proxy" - s.proxyClient, err = s.server.NewClient(auth.TestIdentity{ - I: auth.BuiltinRole{ - Role: types.RoleProxy, - Username: proxyID, - }, - }) - require.NoError(t, err) - - revTunListener, err := net.Listen("tcp", fmt.Sprintf("%v:0", s.server.ClusterName())) - require.NoError(t, err) - - proxyLockWatcher, err := services.NewLockWatcher(s.ctx, services.LockWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Client: s.proxyClient, - }, - }) - require.NoError(t, err) - - proxyNodeWatcher, err := services.NewNodeWatcher(s.ctx, services.NodeWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Client: s.proxyClient, - }, - }) - require.NoError(t, err) - - caWatcher, err := services.NewCertAuthorityWatcher(s.ctx, services.CertAuthorityWatcherConfig{ - ResourceWatcherConfig: services.ResourceWatcherConfig{ - Component: teleport.ComponentProxy, - Client: s.proxyClient, - }, - Types: []types.CertAuthType{types.HostCA, types.UserCA}, - }) - require.NoError(t, err) - defer caWatcher.Close() - - revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ - ID: node.ID(), - Listener: revTunListener, - ClientTLS: s.proxyClient.TLSConfig(), - ClusterName: s.server.ClusterName(), - HostSigners: []ssh.Signer{signer}, - LocalAuthClient: s.proxyClient, - LocalAccessPoint: s.proxyClient, - Emitter: s.proxyClient, - NewCachingAccessPoint: noCache, - DataDir: t.TempDir(), - LockWatcher: proxyLockWatcher, - NodeWatcher: proxyNodeWatcher, - CertAuthorityWatcher: caWatcher, - CircuitBreakerConfig: breaker.NoopBreakerConfig(), - LocalAuthAddresses: []string{s.server.TLS.Listener.Addr().String()}, - }) - require.NoError(t, err) - s.proxyTunnel = revTunServer - - // proxy server: - s.proxy, err = regular.New( - utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, - s.server.ClusterName(), - []ssh.Signer{signer}, - s.proxyClient, - t.TempDir(), - "", - utils.NetAddr{}, - s.proxyClient, - regular.SetUUID(proxyID), - regular.SetProxyMode("", revTunServer, s.proxyClient), - regular.SetEmitter(s.proxyClient), - regular.SetNamespace(apidefaults.Namespace), - regular.SetBPF(&bpf.NOP{}), - regular.SetRestrictedSessionManager(&restricted.NOP{}), - regular.SetClock(s.clock), - regular.SetLockWatcher(proxyLockWatcher), - regular.SetNodeWatcher(proxyNodeWatcher), - ) - require.NoError(t, err) - - // Expired sessions are purged immediately - var sessionLingeringThreshold time.Duration - fs, err := NewDebugFileSystem("../../webassets/teleport") - require.NoError(t, err) - handler, err := NewHandler(Config{ - Proxy: revTunServer, - AuthServers: utils.FromAddr(s.server.TLS.Addr()), - DomainName: s.server.ClusterName(), - ProxyClient: s.proxyClient, - CipherSuites: utils.DefaultCipherSuites(), - AccessPoint: s.proxyClient, - Context: s.ctx, - HostUUID: proxyID, - Emitter: s.proxyClient, - StaticFS: fs, - cachedSessionLingeringThreshold: &sessionLingeringThreshold, - ProxySettings: &mockProxySettings{}, - }, SetSessionStreamPollPeriod(200*time.Millisecond), SetClock(s.clock)) - require.NoError(t, err) - - s.webServer = httptest.NewUnstartedServer(handler) - s.webServer.StartTLS() - err = s.proxy.Start() - require.NoError(t, err) - - // Wait for proxy to fully register before starting the test. - for start := time.Now(); ; { - proxies, err := s.proxyClient.GetProxies() - require.NoError(t, err) - if len(proxies) != 0 { - break - } - if time.Since(start) > 5*time.Second { - t.Fatal("proxy didn't register within 5s after startup") - } - } - - proxyAddr := utils.MustParseAddr(s.proxy.Addr()) - - addr := utils.MustParseAddr(s.webServer.Listener.Addr().String()) - handler.handler.cfg.ProxyWebAddr = *addr - handler.handler.cfg.ProxySSHAddr = *proxyAddr - _, sshPort, err := net.SplitHostPort(proxyAddr.String()) - require.NoError(t, err) - handler.handler.sshPort = sshPort - - t.Cleanup(func() { - // In particular close the lock watchers by canceling the context. - s.cancel() - - s.webServer.Close() - - var errors []error - if err := s.proxyTunnel.Close(); err != nil { - errors = append(errors, err) - } - if err := s.node.Close(); err != nil { - errors = append(errors, err) - } - s.webServer.Close() - if err := s.proxy.Close(); err != nil { - errors = append(errors, err) - } - if err := s.server.Shutdown(context.Background()); err != nil { - errors = append(errors, err) - } - require.Empty(t, errors) - }) - - return s -} - -func noCache(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, error) { - return clt, nil -} - -func (r *authPack) renewSession(ctx context.Context, t *testing.T) *roundtrip.Response { - resp, err := r.clt.PostJSON(ctx, r.clt.Endpoint("webapi", "sessions", "renew"), nil) - require.NoError(t, err) - return resp -} - -func (r *authPack) validateAPI(ctx context.Context, t *testing.T) { - _, err := r.clt.Get(ctx, r.clt.Endpoint("webapi", "sites"), url.Values{}) - require.NoError(t, err) -} - -type authPack struct { - otpSecret string - user string - login string - password string - session *CreateSessionResponse - clt *client.WebClient - cookies []*http.Cookie -} - -// authPack returns new authenticated package consisting of created valid -// user, otp token, created web session and authenticated client. -func (s *WebSuite) authPack(t *testing.T, user string) *authPack { - login := s.user - pass := "abc123" - rawSecret := "def456" - otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret)) - - ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ - Type: constants.Local, - SecondFactor: constants.SecondFactorOTP, - }) - require.NoError(t, err) - err = s.server.Auth().SetAuthPreference(s.ctx, ap) - require.NoError(t, err) - - s.createUser(t, user, login, pass, otpSecret) - - // create a valid otp token - validToken, err := totp.GenerateCode(otpSecret, s.clock.Now()) - require.NoError(t, err) - - clt := s.client() - req := CreateSessionReq{ - User: user, - Pass: pass, - SecondFactorToken: validToken, - } - - csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" - re, err := s.login(clt, csrfToken, csrfToken, req) - require.NoError(t, err) - - var rawSess *CreateSessionResponse - require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) - - sess, err := rawSess.response() - require.NoError(t, err) - - jar, err := cookiejar.New(nil) - require.NoError(t, err) - - clt = s.client(roundtrip.BearerAuth(sess.Token), roundtrip.CookieJar(jar)) - jar.SetCookies(s.url(), re.Cookies()) - - return &authPack{ - otpSecret: otpSecret, - user: user, - login: login, - session: sess, - clt: clt, - cookies: re.Cookies(), - } -} - -func (s *WebSuite) createUser(t *testing.T, user string, login string, pass string, otpSecret string) { - teleUser, err := types.NewUser(user) - require.NoError(t, err) - role := services.RoleForUser(teleUser) - role.SetLogins(types.Allow, []string{login}) - options := role.GetOptions() - options.ForwardAgent = types.NewBool(true) - role.SetOptions(options) - err = s.server.Auth().UpsertRole(s.ctx, role) - require.NoError(t, err) - teleUser.AddRole(role.GetName()) - - teleUser.SetCreatedBy(types.CreatedBy{ - User: types.UserRef{Name: "some-auth-user"}, - }) - err = s.server.Auth().CreateUser(s.ctx, teleUser) - require.NoError(t, err) - - err = s.server.Auth().UpsertPassword(user, []byte(pass)) - require.NoError(t, err) - - if otpSecret != "" { - dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) - require.NoError(t, err) - err = s.server.Auth().UpsertMFADevice(context.Background(), user, dev) - require.NoError(t, err) - } -} - func TestValidRedirectURL(t *testing.T) { t.Parallel() for _, tt := range []struct { @@ -599,7 +203,7 @@ func TestSAML(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() - s := newWebSuite(t) + s := NewTestWebSuite(t) input := tc.rawConnector decoder := kyaml.NewYAMLOrJSONDecoder(strings.NewReader(input), defaults.LookaheadBufSize) @@ -623,14 +227,14 @@ func TestSAML(t *testing.T) { }, }) require.NoError(t, err) - role.SetLogins(types.Allow, []string{s.user}) - err = s.server.Auth().UpsertRole(s.ctx, role) + role.SetLogins(types.Allow, []string{s.User}) + err = s.Server.Auth().UpsertRole(s.Ctx, role) require.NoError(t, err) - err = s.server.Auth().UpsertSAMLConnector(ctx, connector) + err = s.Server.Auth().UpsertSAMLConnector(ctx, connector) require.NoError(t, err) - s.server.Auth().SetClock(clockwork.NewFakeClockAt(time.Date(2017, 5, 10, 18, 53, 0, 0, time.UTC))) - clt := s.clientNoRedirects() + s.Server.Auth().SetClock(clockwork.NewFakeClockAt(time.Date(2017, 5, 10, 18, 53, 0, 0, time.UTC))) + clt := s.ClientNoRedirects() csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" @@ -638,7 +242,7 @@ func TestSAML(t *testing.T) { require.NoError(t, err) req, err := http.NewRequest("GET", baseURL.String(), nil) require.NoError(t, err) - addCSRFCookieToReq(req, csrfToken) + AddCSRFCookieToReq(req, csrfToken) re, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) }) @@ -660,13 +264,13 @@ func TestSAML(t *testing.T) { id := doc.Root().SelectAttr("ID") require.NotNil(t, id) - authRequest, err := s.server.Auth().GetSAMLAuthRequest(context.Background(), id.Value) + authRequest, err := s.Server.Auth().GetSAMLAuthRequest(context.Background(), id.Value) require.NoError(t, err) // now swap the request id to the hardcoded one in fixtures authRequest.ID = fixtures.SAMLOktaAuthRequestID authRequest.CSRFToken = csrfToken - err = s.server.Auth().Services.CreateSAMLAuthRequest(ctx, *authRequest, backend.Forever) + err = s.Server.Auth().Services.CreateSAMLAuthRequest(ctx, *authRequest, backend.Forever) require.NoError(t, err) // now respond with pre-recorded request to the POST url @@ -686,7 +290,7 @@ func TestSAML(t *testing.T) { form.Add("SAMLResponse", encodedResponse) req, err = http.NewRequest("POST", clt.Endpoint("webapi", "saml", "acs"), strings.NewReader(form.Encode())) req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - addCSRFCookieToReq(req, csrfToken) + AddCSRFCookieToReq(req, csrfToken) require.NoError(t, err) authRe, err := clt.Client.RoundTrip(func() (*http.Response, error) { return clt.Client.HTTPClient().Do(req) @@ -706,7 +310,7 @@ func TestSAML(t *testing.T) { func TestWebSessionsCRUD(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) pack := s.authPack(t, "foo") // make sure we can use client to make authenticated requests @@ -730,7 +334,7 @@ func TestWebSessionsCRUD(t *testing.T) { func TestCSRF(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) type input struct { reqToken string cookieToken string @@ -776,12 +380,12 @@ func TestCSRF(t *testing.T) { func TestPasswordChange(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) pack := s.authPack(t, "foo") // invalidate the token - s.clock.Advance(1 * time.Minute) - validToken, err := totp.GenerateCode(pack.otpSecret, s.clock.Now()) + s.Clock.Advance(1 * time.Minute) + validToken, err := totp.GenerateCode(pack.otpSecret, s.Clock.Now()) require.NoError(t, err) req := changePasswordReq{ @@ -818,18 +422,18 @@ func TestValidateBearerToken(t *testing.T) { func TestWebSessionsBadInput(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) user := "bob" pass := "abc123" rawSecret := "def456" otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret)) - err := s.server.Auth().UpsertPassword(user, []byte(pass)) + err := s.Server.Auth().UpsertPassword(user, []byte(pass)) require.NoError(t, err) - dev, err := services.NewTOTPDevice("otp", otpSecret, s.clock.Now()) + dev, err := services.NewTOTPDevice("otp", otpSecret, s.Clock.Now()) require.NoError(t, err) - err = s.server.Auth().UpsertMFADevice(context.Background(), user, dev) + err = s.Server.Auth().UpsertMFADevice(context.Background(), user, dev) require.NoError(t, err) // create valid token @@ -871,7 +475,7 @@ func TestWebSessionsBadInput(t *testing.T) { } for i, req := range reqs { t.Run(fmt.Sprintf("tc %v", i), func(t *testing.T) { - _, err := clt.PostJSON(s.ctx, clt.Endpoint("webapi", "sessions"), req) + _, err := clt.PostJSON(s.Ctx, clt.Endpoint("webapi", "sessions"), req) require.Error(t, err) require.True(t, trace.IsAccessDenied(err)) }) @@ -950,7 +554,7 @@ func TestClusterAlertsGet(t *testing.T) { func TestSiteNodeConnectInvalidSessionID(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) _, err := s.makeTerminal(t, s.authPack(t, "foo"), withSessionID(session.ID("/../../../foo"))) require.Error(t, err) } @@ -1135,7 +739,7 @@ func TestNewTerminalHandler(t *testing.T) { func TestResizeTerminal(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) sid := session.NewID() // Create a new user "foo", open a terminal to a new session, and wait for @@ -1210,7 +814,7 @@ func TestResizeTerminal(t *testing.T) { // TestTerminalPing tests that the server sends continuous ping control messages. func TestTerminalPing(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) ws, err := s.makeTerminal(t, s.authPack(t, "foo"), withKeepaliveInterval(500*time.Millisecond)) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -1252,7 +856,7 @@ func TestTerminalPing(t *testing.T) { func TestTerminal(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) ws, err := s.makeTerminal(t, s.authPack(t, "foo")) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -1526,7 +1130,7 @@ func handleMFAWebauthnChallenge(t *testing.T, ws *websocket.Conn, dev *auth.Test func TestWebAgentForward(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) ws, err := s.makeTerminal(t, s.authPack(t, "foo")) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) }) @@ -1543,7 +1147,7 @@ func TestWebAgentForward(t *testing.T) { func TestActiveSessions(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) sid := session.NewID() pack := s.authPack(t, "foo") @@ -1565,7 +1169,7 @@ func TestActiveSessions(t *testing.T) { var sessResp *siteSessionsGetResponse require.Eventually(t, func() bool { // Get site nodes and make sure the node has our active party. - re, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "sessions"), url.Values{}) + re, err := pack.clt.Get(s.Ctx, pack.clt.Endpoint("webapi", "sites", s.Server.ClusterName(), "sessions"), url.Values{}) require.NoError(t, err) require.NoError(t, json.Unmarshal(re.Bytes(), &sessResp)) return len(sessResp.Sessions) == 1 @@ -1573,22 +1177,22 @@ func TestActiveSessions(t *testing.T) { sess := sessResp.Sessions[0] require.Equal(t, sid, sess.ID) - require.Equal(t, s.node.GetNamespace(), sess.Namespace) + require.Equal(t, s.Node.GetNamespace(), sess.Namespace) require.NotNil(t, sess.Parties) require.Greater(t, sess.TerminalParams.H, 0) require.Greater(t, sess.TerminalParams.W, 0) require.Equal(t, pack.login, sess.Login) require.False(t, sess.Created.IsZero()) require.False(t, sess.LastActive.IsZero()) - require.Equal(t, s.srvID, sess.ServerID) - require.Equal(t, s.node.GetInfo().GetHostname(), sess.ServerHostname) - require.Equal(t, s.srvID, sess.ServerAddr) - require.Equal(t, s.server.ClusterName(), sess.ClusterName) + require.Equal(t, s.SrvID, sess.ServerID) + require.Equal(t, s.Node.GetInfo().GetHostname(), sess.ServerHostname) + require.Equal(t, s.SrvID, sess.ServerAddr) + require.Equal(t, s.Server.ClusterName(), sess.ClusterName) } func TestCloseConnectionsOnLogout(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) sid := session.NewID() pack := s.authPack(t, "foo") @@ -1608,7 +1212,7 @@ func TestCloseConnectionsOnLogout(t *testing.T) { _, err = stream.Read(out) require.NoError(t, err) - _, err = pack.clt.Delete(s.ctx, pack.clt.Endpoint("webapi", "sessions")) + _, err = pack.clt.Delete(s.Ctx, pack.clt.Endpoint("webapi", "sessions")) require.NoError(t, err) // wait until we timeout or detect that connection has been closed @@ -1677,7 +1281,7 @@ func TestCreateSession(t *testing.T) { func TestPlayback(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) pack := s.authPack(t, "foo") sid := session.NewID() ws, err := s.makeTerminal(t, pack, withSessionID(sid)) @@ -1687,13 +1291,13 @@ func TestPlayback(t *testing.T) { func TestLogin(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ Type: constants.Local, SecondFactor: constants.SecondFactorOff, }) require.NoError(t, err) - err = s.server.Auth().SetAuthPreference(s.ctx, ap) + err = s.Server.Auth().SetAuthPreference(s.Ctx, ap) require.NoError(t, err) // create user @@ -1712,7 +1316,7 @@ func TestLogin(t *testing.T) { req.Header.Set("User-Agent", ua) csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" - addCSRFCookieToReq(req, csrfToken) + AddCSRFCookieToReq(req, csrfToken) req.Header.Set("Content-Type", "application/json") req.Header.Set(csrf.HeaderName, csrfToken) @@ -1721,9 +1325,9 @@ func TestLogin(t *testing.T) { }) require.NoError(t, err) - events, _, err := s.server.AuthServer.AuditLog.SearchEvents( - s.clock.Now().Add(-time.Hour), - s.clock.Now().Add(time.Hour), + events, _, err := s.Server.AuthServer.AuditLog.SearchEvents( + s.Clock.Now().Add(-time.Hour), + s.Clock.Now().Add(time.Hour), apidefaults.Namespace, []string{events.UserLoginEvent}, 1, @@ -1750,7 +1354,7 @@ func TestLogin(t *testing.T) { clt = s.client(roundtrip.BearerAuth(rawSess.Token), roundtrip.CookieJar(jar)) jar.SetCookies(s.url(), re.Cookies()) - re, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + re, err = clt.Get(s.Ctx, clt.Endpoint("webapi", "sites"), url.Values{}) require.NoError(t, err) var clusters []ui.Cluster @@ -1760,13 +1364,13 @@ func TestLogin(t *testing.T) { // no session cookie: clt = s.client(roundtrip.BearerAuth(rawSess.Token)) - _, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + _, err = clt.Get(s.Ctx, clt.Endpoint("webapi", "sites"), url.Values{}) require.Error(t, err) require.True(t, trace.IsAccessDenied(err)) // no bearer token: clt = s.client(roundtrip.CookieJar(jar)) - _, err = clt.Get(s.ctx, clt.Endpoint("webapi", "sites"), url.Values{}) + _, err = clt.Get(s.Ctx, clt.Endpoint("webapi", "sites"), url.Values{}) require.Error(t, err) require.True(t, trace.IsAccessDenied(err)) } @@ -1775,14 +1379,14 @@ func TestLogin(t *testing.T) { // /webapi/motd work when no MotD is set func TestEmptyMotD(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) wc := s.client() // Given an auth server configured *not* to expose a Message Of The // Day... // When I issue a ping request... - re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) + re, err := wc.Get(s.Ctx, wc.Endpoint("webapi", "ping"), url.Values{}) require.NoError(t, err) // Expect that the MotD flag in the ping response is *not* set @@ -1791,7 +1395,7 @@ func TestEmptyMotD(t *testing.T) { require.False(t, pingResponse.Auth.HasMessageOfTheDay) // When I fetch the MotD... - re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "motd"), url.Values{}) + re, err = wc.Get(s.Ctx, wc.Endpoint("webapi", "motd"), url.Values{}) require.NoError(t, err) // Expect that an empty response returned @@ -1806,16 +1410,16 @@ func TestMotD(t *testing.T) { t.Parallel() const motd = "Hello. I'm a Teleport cluster!" - s := newWebSuite(t) + s := NewTestWebSuite(t) wc := s.client() // Given an auth server configured to expose a Message Of The Day... prefs := types.DefaultAuthPreference() prefs.SetMessageOfTheDay(motd) - require.NoError(t, s.server.AuthServer.AuthServer.SetAuthPreference(s.ctx, prefs)) + require.NoError(t, s.Server.AuthServer.AuthServer.SetAuthPreference(s.Ctx, prefs)) // When I issue a ping request... - re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) + re, err := wc.Get(s.Ctx, wc.Endpoint("webapi", "ping"), url.Values{}) require.NoError(t, err) // Expect that the MotD flag in the ping response is set to indicate @@ -1825,7 +1429,7 @@ func TestMotD(t *testing.T) { require.True(t, pingResponse.Auth.HasMessageOfTheDay) // When I fetch the MotD... - re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "motd"), url.Values{}) + re, err = wc.Get(s.Ctx, wc.Endpoint("webapi", "motd"), url.Values{}) require.NoError(t, err) // Expect that the text returned is the configured value @@ -1836,7 +1440,7 @@ func TestMotD(t *testing.T) { func TestMultipleConnectors(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) wc := s.client() // create two oidc connectors, one named "foo" and another named "bar" @@ -1857,11 +1461,11 @@ func TestMultipleConnectors(t *testing.T) { } o, err := types.NewOIDCConnector("foo", oidcConnectorSpec) require.NoError(t, err) - err = s.server.Auth().UpsertOIDCConnector(s.ctx, o) + err = s.Server.Auth().UpsertOIDCConnector(s.Ctx, o) require.NoError(t, err) o2, err := types.NewOIDCConnector("bar", oidcConnectorSpec) require.NoError(t, err) - err = s.server.Auth().UpsertOIDCConnector(s.ctx, o2) + err = s.Server.Auth().UpsertOIDCConnector(s.Ctx, o2) require.NoError(t, err) // set the auth preferences to oidc with no connector name @@ -1869,18 +1473,18 @@ func TestMultipleConnectors(t *testing.T) { Type: "oidc", }) require.NoError(t, err) - err = s.server.Auth().SetAuthPreference(s.ctx, authPreference) + err = s.Server.Auth().SetAuthPreference(s.Ctx, authPreference) require.NoError(t, err) // hit the ping endpoint to get the auth type and connector name - re, err := wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) + re, err := wc.Get(s.Ctx, wc.Endpoint("webapi", "ping"), url.Values{}) require.NoError(t, err) var out *webclient.PingResponse require.NoError(t, json.Unmarshal(re.Bytes(), &out)) // make sure the connector name we got back was the first connector // in the backend, in this case it's "bar" - oidcConnectors, err := s.server.Auth().GetOIDCConnectors(s.ctx, false) + oidcConnectors, err := s.Server.Auth().GetOIDCConnectors(s.Ctx, false) require.NoError(t, err) require.Equal(t, oidcConnectors[0].GetName(), out.Auth.OIDC.Name) @@ -1890,11 +1494,11 @@ func TestMultipleConnectors(t *testing.T) { ConnectorName: "foo", }) require.NoError(t, err) - err = s.server.Auth().SetAuthPreference(s.ctx, authPreference) + err = s.Server.Auth().SetAuthPreference(s.Ctx, authPreference) require.NoError(t, err) // hit the ping endpoing to get the auth type and connector name - re, err = wc.Get(s.ctx, wc.Endpoint("webapi", "ping"), url.Values{}) + re, err = wc.Get(s.Ctx, wc.Endpoint("webapi", "ping"), url.Values{}) require.NoError(t, err) require.NoError(t, json.Unmarshal(re.Bytes(), &out)) @@ -2004,16 +1608,16 @@ func (f byTimeAndIndex) Swap(i, j int) { func TestSearchClusterEvents(t *testing.T) { t.Parallel() - s := newWebSuite(t) - clock := s.clock + s := NewTestWebSuite(t) + clock := s.Clock sessionEvents := events.GenerateTestSession(events.SessionParams{ PrintEvents: 3, Clock: clock, - ServerID: s.proxy.ID(), + ServerID: s.Proxy.ID(), }) for _, e := range sessionEvents { - require.NoError(t, s.proxyClient.EmitAuditEvent(s.ctx, e)) + require.NoError(t, s.ProxyClient.EmitAuditEvent(s.Ctx, e)) } sort.Sort(sort.Reverse(byTimeAndIndex(sessionEvents))) @@ -2103,7 +1707,7 @@ func TestSearchClusterEvents(t *testing.T) { tc := tc t.Run(tc.Comment, func(t *testing.T) { t.Parallel() - response, err := pack.clt.Get(s.ctx, pack.clt.Endpoint("webapi", "sites", s.server.ClusterName(), "events", "search"), tc.Query) + response, err := pack.clt.Get(s.Ctx, pack.clt.Endpoint("webapi", "sites", s.Server.ClusterName(), "events", "search"), tc.Query) require.NoError(t, err) var result eventsListGetResponse require.NoError(t, json.Unmarshal(response.Bytes(), &result)) @@ -2136,21 +1740,21 @@ func TestSearchClusterEvents(t *testing.T) { func TestGetClusterDetails(t *testing.T) { t.Parallel() - s := newWebSuite(t) - site, err := s.proxyTunnel.GetSite(s.server.ClusterName()) + s := NewTestWebSuite(t) + site, err := s.ProxyTunnel.GetSite(s.Server.ClusterName()) require.NoError(t, err) require.NotNil(t, site) - cluster, err := ui.GetClusterDetails(s.ctx, site) + cluster, err := ui.GetClusterDetails(s.Ctx, site) require.NoError(t, err) - require.Equal(t, s.server.ClusterName(), cluster.Name) + require.Equal(t, s.Server.ClusterName(), cluster.Name) require.Equal(t, teleport.Version, cluster.ProxyVersion) - require.Equal(t, fmt.Sprintf("%v:%v", s.server.ClusterName(), defaults.HTTPListenPort), cluster.PublicURL) + require.Equal(t, fmt.Sprintf("%v:%v", s.Server.ClusterName(), defaults.HTTPListenPort), cluster.PublicURL) require.Equal(t, teleport.RemoteClusterStatusOnline, cluster.Status) require.NotNil(t, cluster.LastConnected) require.Equal(t, teleport.Version, cluster.AuthVersion) - nodes, err := s.proxyClient.GetNodes(s.ctx, apidefaults.Namespace) + nodes, err := s.ProxyClient.GetNodes(s.Ctx, apidefaults.Namespace) require.NoError(t, err) require.Len(t, nodes, cluster.NodeCount) } @@ -3350,7 +2954,7 @@ func TestCreateRegisterChallenge(t *testing.T) { // be exchanged for an application specific session. func TestCreateAppSession(t *testing.T) { t.Parallel() - s := newWebSuite(t) + s := NewTestWebSuite(t) pack := s.authPack(t, "foo@example.com") // Register an application called "panel". @@ -3363,7 +2967,7 @@ func TestCreateAppSession(t *testing.T) { require.NoError(t, err) server, err := types.NewAppServerV3FromApp(app, "host", uuid.New().String()) require.NoError(t, err) - _, err = s.server.Auth().UpsertApplicationServer(s.ctx, server) + _, err = s.Server.Auth().UpsertApplicationServer(s.Ctx, server) require.NoError(t, err) // Extract the session ID and bearer token for the current session. @@ -3469,7 +3073,7 @@ func TestCreateAppSession(t *testing.T) { t.Parallel() // Make a request to create an application session for "panel". endpoint := pack.clt.Endpoint("webapi", "sessions", "app") - resp, err := pack.clt.PostJSON(s.ctx, endpoint, tt.inCreateRequest) + resp, err := pack.clt.PostJSON(s.Ctx, endpoint, tt.inCreateRequest) tt.outError(t, err) if err != nil { return @@ -3481,7 +3085,7 @@ func TestCreateAppSession(t *testing.T) { require.Equal(t, tt.outFQDN, response.FQDN) // Verify that the application session was created. - sess, err := s.server.Auth().GetAppSession(s.ctx, types.GetAppSessionRequest{ + sess, err := s.Server.Auth().GetAppSession(s.Ctx, types.GetAppSessionRequest{ SessionID: response.CookieValue, }) require.NoError(t, err) @@ -3712,33 +3316,33 @@ func TestParseSSORequestParams(t *testing.T) { tests := []struct { name, url string wantErr bool - expected *ssoRequestParams + expected *SSORequestParams }{ { name: "preserve redirect's query params (escaped)", url: "https://localhost/login?connector_id=oidc&redirect_url=https:%2F%2Flocalhost:8080%2Fweb%2Fcluster%2Fim-a-cluster-name%2Fnodes%3Fsearch=tunnel&sort=hostname:asc", - expected: &ssoRequestParams{ - clientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", - connectorID: "oidc", - csrfToken: token, + expected: &SSORequestParams{ + ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", + ConnectorID: "oidc", + CSRFToken: token, }, }, { name: "preserve redirect's query params (unescaped)", url: "https://localhost/login?connector_id=github&redirect_url=https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", - expected: &ssoRequestParams{ - clientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", - connectorID: "github", - csrfToken: token, + expected: &SSORequestParams{ + ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", + ConnectorID: "github", + CSRFToken: token, }, }, { name: "preserve various encoded chars", url: "https://localhost/login?connector_id=saml&redirect_url=https:%2F%2Flocalhost:8080%2Fweb%2Fcluster%2Fim-a-cluster-name%2Fapps%3Fquery=search(%2522watermelon%2522%252C%2520%2522this%2522)%2520%2526%2526%2520labels%255B%2522unique-id%2522%255D%2520%253D%253D%2520%2522hi%2522&sort=name:asc", - expected: &ssoRequestParams{ - clientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/apps?query=search(%22watermelon%22%2C%20%22this%22)%20%26%26%20labels%5B%22unique-id%22%5D%20%3D%3D%20%22hi%22&sort=name:asc", - connectorID: "saml", - csrfToken: token, + expected: &SSORequestParams{ + ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/apps?query=search(%22watermelon%22%2C%20%22this%22)%20%26%26%20labels%5B%22unique-id%22%5D%20%3D%3D%20%22hi%22&sort=name:asc", + ConnectorID: "saml", + CSRFToken: token, }, }, { @@ -3757,9 +3361,9 @@ func TestParseSSORequestParams(t *testing.T) { t.Run(tc.name, func(t *testing.T) { req, err := http.NewRequest("", tc.url, nil) require.NoError(t, err) - addCSRFCookieToReq(req, token) + AddCSRFCookieToReq(req, token) - params, err := parseSSORequestParams(req) + params, err := ParseSSORequestParams(req) switch { case tc.wantErr: @@ -4478,65 +4082,6 @@ func (mock authProviderMock) GetSessionTracker(ctx context.Context, sessionID st return nil, trace.NotFound("foo") } -type terminalOpt func(t *TerminalRequest) - -func withSessionID(sid session.ID) terminalOpt { - return func(t *TerminalRequest) { t.SessionID = sid } -} - -func withKeepaliveInterval(d time.Duration) terminalOpt { - return func(t *TerminalRequest) { t.KeepAliveInterval = d } -} - -func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOpt) (*websocket.Conn, error) { - req := TerminalRequest{ - Server: s.srvID, - Login: pack.login, - Term: session.TerminalParams{ - W: 100, - H: 100, - }, - SessionID: session.NewID(), - } - for _, opt := range opts { - opt(&req) - } - - u := url.URL{ - Host: s.url().Host, - Scheme: client.WSS, - Path: fmt.Sprintf("/v1/webapi/sites/%v/connect", currentSiteShortcut), - } - data, err := json.Marshal(req) - if err != nil { - return nil, err - } - - q := u.Query() - q.Set("params", string(data)) - q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) - u.RawQuery = q.Encode() - - dialer := websocket.Dialer{} - dialer.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, - } - - header := http.Header{} - header.Add("Origin", "http://localhost") - for _, cookie := range pack.cookies { - header.Add("Cookie", cookie.String()) - } - - ws, resp, err := dialer.Dial(u.String(), header) - if err != nil { - return nil, trace.Wrap(err) - } - - require.NoError(t, resp.Body.Close()) - return ws, nil -} - func waitForOutput(stream *terminalStream, substr string) error { timeoutCh := time.After(10 * time.Second) @@ -4558,203 +4103,6 @@ func waitForOutput(stream *terminalStream, substr string) error { } } -func (s *WebSuite) waitForRawEvent(ws *websocket.Conn, timeout time.Duration) error { - timeoutContext, timeoutCancel := context.WithTimeout(s.ctx, timeout) - defer timeoutCancel() - - done := make(chan error, 1) - - go func() { - for { - ty, raw, err := ws.ReadMessage() - if err != nil { - done <- trace.Wrap(err) - return - } - - if ty != websocket.BinaryMessage { - done <- trace.BadParameter("expected binary message, got %v", ty) - return - } - - var envelope Envelope - err = proto.Unmarshal(raw, &envelope) - if err != nil { - done <- trace.Wrap(err) - return - } - - if envelope.GetType() == defaults.WebsocketRaw { - done <- nil - return - } - } - }() - - for { - select { - case <-timeoutContext.Done(): - return trace.BadParameter("timeout waiting for raw event") - case err := <-done: - return trace.Wrap(err) - } - } -} - -func (s *WebSuite) waitForResizeEvent(ws *websocket.Conn, timeout time.Duration) error { - timeoutContext, timeoutCancel := context.WithTimeout(s.ctx, timeout) - defer timeoutCancel() - - done := make(chan error, 1) - - go func() { - for { - ty, raw, err := ws.ReadMessage() - if err != nil { - done <- trace.Wrap(err) - return - } - - if ty != websocket.BinaryMessage { - done <- trace.BadParameter("expected binary message, got %v", ty) - return - } - - var envelope Envelope - err = proto.Unmarshal(raw, &envelope) - if err != nil { - done <- trace.Wrap(err) - return - } - - if envelope.GetType() != defaults.WebsocketAudit { - continue - } - - var e events.EventFields - err = json.Unmarshal([]byte(envelope.GetPayload()), &e) - if err != nil { - done <- trace.Wrap(err) - return - } - - if e.GetType() == events.ResizeEvent { - done <- nil - return - } - } - }() - - for { - select { - case <-timeoutContext.Done(): - return trace.BadParameter("timeout waiting for resize event") - case err := <-done: - return trace.Wrap(err) - } - } -} - -func (s *WebSuite) listenForResizeEvent(ws *websocket.Conn) chan struct{} { - ch := make(chan struct{}) - - go func() { - for { - ty, raw, err := ws.ReadMessage() - if err != nil { - close(ch) - return - } - - if ty != websocket.BinaryMessage { - close(ch) - return - } - - var envelope Envelope - err = proto.Unmarshal(raw, &envelope) - if err != nil { - close(ch) - return - } - - if envelope.GetType() != defaults.WebsocketAudit { - continue - } - - var e events.EventFields - err = json.Unmarshal([]byte(envelope.GetPayload()), &e) - if err != nil { - close(ch) - return - } - - if e.GetType() == events.ResizeEvent { - ch <- struct{}{} - return - } - } - }() - - return ch -} - -func (s *WebSuite) clientNoRedirects(opts ...roundtrip.ClientParam) *client.WebClient { - hclient := client.NewInsecureWebClient() - hclient.CheckRedirect = func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - } - opts = append(opts, roundtrip.HTTPClient(hclient)) - wc, err := client.NewWebClient(s.url().String(), opts...) - if err != nil { - panic(err) - } - return wc -} - -func (s *WebSuite) client(opts ...roundtrip.ClientParam) *client.WebClient { - opts = append(opts, roundtrip.HTTPClient(client.NewInsecureWebClient())) - wc, err := client.NewWebClient(s.url().String(), opts...) - if err != nil { - panic(err) - } - return wc -} - -func (s *WebSuite) login(clt *client.WebClient, cookieToken string, reqToken string, reqData interface{}) (*roundtrip.Response, error) { - return httplib.ConvertResponse(clt.RoundTrip(func() (*http.Response, error) { - data, err := json.Marshal(reqData) - if err != nil { - return nil, err - } - req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions"), bytes.NewBuffer(data)) - if err != nil { - return nil, err - } - addCSRFCookieToReq(req, cookieToken) - req.Header.Set("Content-Type", "application/json") - req.Header.Set(csrf.HeaderName, reqToken) - return clt.HTTPClient().Do(req) - })) -} - -func (s *WebSuite) url() *url.URL { - u, err := url.Parse("https://" + s.webServer.Listener.Addr().String()) - if err != nil { - panic(err) - } - return u -} - -func addCSRFCookieToReq(req *http.Request, token string) { - cookie := &http.Cookie{ - Name: csrf.CookieName, - Value: token, - } - - req.AddCookie(cookie) -} - func removeSpace(in string) string { for _, c := range []string{"\n", "\r", "\t"} { in = strings.Replace(in, c, " ", -1) @@ -4782,10 +4130,6 @@ func decodeSessionCookie(t *testing.T, value string) (sessionID string) { return cookie.SessionID } -func (r CreateSessionResponse) response() (*CreateSessionResponse, error) { - return &CreateSessionResponse{TokenType: r.TokenType, Token: r.Token, TokenExpiresIn: r.TokenExpiresIn, SessionInactiveTimeoutMS: r.SessionInactiveTimeoutMS}, nil -} - func newWebPack(t *testing.T, numProxies int) *webPack { ctx := context.Background() clock := clockwork.NewFakeClockAt(time.Now()) @@ -5305,7 +4649,7 @@ func login(t *testing.T, clt *client.WebClient, cookieToken, reqToken string, re if err != nil { return nil, err } - addCSRFCookieToReq(req, cookieToken) + AddCSRFCookieToReq(req, cookieToken) req.Header.Set("Content-Type", "application/json") req.Header.Set(csrf.HeaderName, reqToken) return clt.HTTPClient().Do(req) @@ -5324,12 +4668,6 @@ func validateTerminalStream(t *testing.T, conn *websocket.Conn) { require.NoError(t, err) } -type mockProxySettings struct{} - -func (mock *mockProxySettings) GetProxySettings(ctx context.Context) (*webclient.ProxySettings, error) { - return &webclient.ProxySettings{}, nil -} - // TestUserContextWithAccessRequest checks that the userContext includes the ID of the // access request after it has been consumed and the web session has been renewed. func TestUserContextWithAccessRequest(t *testing.T) { diff --git a/lib/web/saml.go b/lib/web/saml.go index f16fb08ce3168..a18173c7f6957 100644 --- a/lib/web/saml.go +++ b/lib/web/saml.go @@ -35,17 +35,17 @@ func (h *Handler) samlSSO(w http.ResponseWriter, r *http.Request, p httprouter.P logger := h.log.WithField("auth", "saml") logger.Debug("Web login start.") - req, err := parseSSORequestParams(r) + req, err := ParseSSORequestParams(r) if err != nil { logger.WithError(err).Error("Failed to extract SSO parameters from request.") return client.LoginFailedRedirectURL } response, err := h.cfg.ProxyClient.CreateSAMLAuthRequest(r.Context(), types.SAMLAuthRequest{ - ConnectorID: req.connectorID, - CSRFToken: req.csrfToken, + ConnectorID: req.ConnectorID, + CSRFToken: req.CSRFToken, CreateWebSession: true, - ClientRedirectURL: req.clientRedirectURL, + ClientRedirectURL: req.ClientRedirectURL, }) if err != nil { logger.WithError(err).Error("Error creating auth request.") @@ -62,12 +62,12 @@ func (h *Handler) samlSSOConsole(w http.ResponseWriter, r *http.Request, p httpr req := new(client.SSOLoginConsoleReq) if err := httplib.ReadJSON(r, req); err != nil { logger.WithError(err).Error("Error reading json.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginConsoleErr) } if err := req.CheckAndSetDefaults(); err != nil { logger.WithError(err).Error("Missing request parameters.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginConsoleErr) } response, err := h.cfg.ProxyClient.CreateSAMLAuthRequest(r.Context(), types.SAMLAuthRequest{ @@ -82,7 +82,7 @@ func (h *Handler) samlSSOConsole(w http.ResponseWriter, r *http.Request, p httpr }) if err != nil { logger.WithError(err).Error("Failed to create SAML auth request.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginConsoleErr) } return &client.SSOLoginConsoleResponse{RedirectURL: response.RedirectURL}, nil @@ -109,7 +109,7 @@ func (h *Handler) samlACS(w http.ResponseWriter, r *http.Request, p httprouter.P // this improves the UX by terminating the failed SSO flow immediately, rather than hoping for a timeout. if requestID, errParse := auth.ParseSAMLInResponseTo(samlResponse); errParse == nil { if request, errGet := h.cfg.ProxyClient.GetSAMLAuthRequest(r.Context(), requestID); errGet == nil && !request.CreateWebSession { - if url, errEnc := redirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { + if url, errEnc := RedirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { return url.String() } } @@ -131,19 +131,19 @@ func (h *Handler) samlACS(w http.ResponseWriter, r *http.Request, p httprouter.P redirect = "/web/" } - res := &ssoCallbackResponse{ - csrfToken: response.Req.CSRFToken, - username: response.Username, - sessionName: response.Session.GetName(), - clientRedirectURL: redirect, + res := &SSOCallbackResponse{ + CSRFToken: response.Req.CSRFToken, + Username: response.Username, + SessionName: response.Session.GetName(), + ClientRedirectURL: redirect, } - if err := ssoSetWebSessionAndRedirectURL(w, r, res, response.Req.CSRFToken != ""); err != nil { + if err := SSOSetWebSessionAndRedirectURL(w, r, res, response.Req.CSRFToken != ""); err != nil { logger.WithError(err).Error("Error setting web session.") return client.LoginFailedRedirectURL } - return res.clientRedirectURL + return res.ClientRedirectURL } logger.Debug("Callback redirecting to console login.") diff --git a/lib/web/websuite.go b/lib/web/websuite.go new file mode 100644 index 0000000000000..94863c0ff49c7 --- /dev/null +++ b/lib/web/websuite.go @@ -0,0 +1,750 @@ +package web + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/base32" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "net/url" + "os/user" + "testing" + "time" + + proto "github.com/gogo/protobuf/proto" + "github.com/gorilla/websocket" + "github.com/gravitational/roundtrip" + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/breaker" + authproto "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/client/webclient" + "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/auth/mocku2f" + "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/bpf" + "github.com/gravitational/teleport/lib/client" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/httplib" + "github.com/gravitational/teleport/lib/httplib/csrf" + "github.com/gravitational/teleport/lib/pam" + "github.com/gravitational/teleport/lib/plugin" + restricted "github.com/gravitational/teleport/lib/restrictedsession" + "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/srv/regular" + "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/pquerna/otp/totp" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +const hostID = "00000000-0000-0000-0000-000000000000" + +// TestWebSuite is a suite of components for testing the web package. It exists +// as an exported struct not in a _test.go file so that it can be used from +// outside this package by external packages that extend the web API (such as +// SAML and OIDC auth connectors in the enterprise edition). +type TestWebSuite struct { + Ctx context.Context + Cancel context.CancelFunc + + Node *regular.Server + Proxy *regular.Server + ProxyTunnel reversetunnel.Server + SrvID string + + User string + WebServer *httptest.Server + + MockU2F *mocku2f.Key + Server *auth.TestServer + ProxyClient *auth.Client + Clock clockwork.FakeClock +} + +type TestWebSuiteConfig struct { + AssetDir string + PluginRegistry plugin.Registry +} + +type TestWebSuiteOption func(cfg *TestWebSuiteConfig) + +// WithWebSuiteAssetDir configures a TestWebSuite with an asset directory +// other than the default of ../../webassets/teleport, as this path only +// works from one level of the directory hierarchy. +func WithWebSuiteAssetDir(dir string) TestWebSuiteOption { + return func(cfg *TestWebSuiteConfig) { + cfg.AssetDir = dir + } +} + +// WithWebSuitePluginRegistry configures a TestWebSuite with a plugin +// registry for the web.Handler created for the test suite, allowing external +// plugins to configure a web suite for testing +func WithWebSuitePluginRegistry(reg plugin.Registry) TestWebSuiteOption { + return func(cfg *TestWebSuiteConfig) { + cfg.PluginRegistry = reg + } +} + +func NewTestWebSuite(t *testing.T, opts ...TestWebSuiteOption) *TestWebSuite { + cfg := &TestWebSuiteConfig{ + AssetDir: "../../websuite/teleport", + } + for _, opt := range opts { + opt(cfg) + } + + mockU2F, err := mocku2f.Create() + require.NoError(t, err) + require.NotNil(t, mockU2F) + + u, err := user.Current() + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + s := &TestWebSuite{ + MockU2F: mockU2F, + Clock: clockwork.NewFakeClock(), + User: u.Username, + Ctx: ctx, + Cancel: cancel, + } + + networkingConfig, err := types.NewClusterNetworkingConfigFromConfigFile(types.ClusterNetworkingConfigSpecV2{ + KeepAliveInterval: types.Duration(10 * time.Second), + }) + require.NoError(t, err) + + s.Server, err = auth.NewTestServer(auth.TestServerConfig{ + Auth: auth.TestAuthServerConfig{ + ClusterName: "localhost", + Dir: t.TempDir(), + Clock: s.Clock, + ClusterNetworkingConfig: networkingConfig, + }, + }) + require.NoError(t, err) + + // Register the auth server, since test auth server doesn't start its own + // heartbeat. + err = s.Server.Auth().UpsertAuthServer(&types.ServerV2{ + Kind: types.KindAuthServer, + Version: types.V2, + Metadata: types.Metadata{ + Namespace: apidefaults.Namespace, + Name: "auth", + }, + Spec: types.ServerSpecV2{ + Addr: s.Server.TLS.Listener.Addr().String(), + Hostname: "localhost", + Version: teleport.Version, + }, + }) + require.NoError(t, err) + + priv, pub, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + + tlsPub, err := auth.PrivateKeyToPublicKeyTLS(priv) + require.NoError(t, err) + + // start node + certs, err := s.Server.Auth().GenerateHostCerts(s.Ctx, + &authproto.HostCertsRequest{ + HostID: hostID, + NodeName: s.Server.ClusterName(), + Role: types.RoleNode, + PublicSSHKey: pub, + PublicTLSKey: tlsPub, + }) + require.NoError(t, err) + + signer, err := sshutils.NewSigner(priv, certs.SSH) + require.NoError(t, err) + + nodeID := "node" + nodeClient, err := s.Server.NewClient(auth.TestIdentity{ + I: auth.BuiltinRole{ + Role: types.RoleNode, + Username: nodeID, + }, + }) + require.NoError(t, err) + + nodeLockWatcher, err := services.NewLockWatcher(s.Ctx, services.LockWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentNode, + Client: nodeClient, + }, + }) + require.NoError(t, err) + + // create SSH service: + nodeDataDir := t.TempDir() + node, err := regular.New( + utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, + s.Server.ClusterName(), + []ssh.Signer{signer}, + nodeClient, + nodeDataDir, + "", + utils.NetAddr{}, + nodeClient, + regular.SetUUID(nodeID), + regular.SetNamespace(apidefaults.Namespace), + regular.SetShell("/bin/sh"), + regular.SetEmitter(nodeClient), + regular.SetPAMConfig(&pam.Config{Enabled: false}), + regular.SetBPF(&bpf.NOP{}), + regular.SetRestrictedSessionManager(&restricted.NOP{}), + regular.SetClock(s.Clock), + regular.SetLockWatcher(nodeLockWatcher), + ) + require.NoError(t, err) + s.Node = node + s.SrvID = node.ID() + require.NoError(t, s.Node.Start()) + require.NoError(t, auth.CreateUploaderDir(nodeDataDir)) + + // create reverse tunnel service: + proxyID := "proxy" + s.ProxyClient, err = s.Server.NewClient(auth.TestIdentity{ + I: auth.BuiltinRole{ + Role: types.RoleProxy, + Username: proxyID, + }, + }) + require.NoError(t, err) + + revTunListener, err := net.Listen("tcp", fmt.Sprintf("%v:0", s.Server.ClusterName())) + require.NoError(t, err) + + proxyLockWatcher, err := services.NewLockWatcher(s.Ctx, services.LockWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.ProxyClient, + }, + }) + require.NoError(t, err) + + proxyNodeWatcher, err := services.NewNodeWatcher(s.Ctx, services.NodeWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.ProxyClient, + }, + }) + require.NoError(t, err) + + caWatcher, err := services.NewCertAuthorityWatcher(s.Ctx, services.CertAuthorityWatcherConfig{ + ResourceWatcherConfig: services.ResourceWatcherConfig{ + Component: teleport.ComponentProxy, + Client: s.ProxyClient, + }, + Types: []types.CertAuthType{types.HostCA, types.UserCA}, + }) + require.NoError(t, err) + defer caWatcher.Close() + + revTunServer, err := reversetunnel.NewServer(reversetunnel.Config{ + ID: node.ID(), + Listener: revTunListener, + ClientTLS: s.ProxyClient.TLSConfig(), + ClusterName: s.Server.ClusterName(), + HostSigners: []ssh.Signer{signer}, + LocalAuthClient: s.ProxyClient, + LocalAccessPoint: s.ProxyClient, + Emitter: s.ProxyClient, + NewCachingAccessPoint: noCache, + DataDir: t.TempDir(), + LockWatcher: proxyLockWatcher, + NodeWatcher: proxyNodeWatcher, + CertAuthorityWatcher: caWatcher, + CircuitBreakerConfig: breaker.NoopBreakerConfig(), + LocalAuthAddresses: []string{s.Server.TLS.Listener.Addr().String()}, + }) + require.NoError(t, err) + s.ProxyTunnel = revTunServer + + // proxy server: + s.Proxy, err = regular.New( + utils.NetAddr{AddrNetwork: "tcp", Addr: "127.0.0.1:0"}, + s.Server.ClusterName(), + []ssh.Signer{signer}, + s.ProxyClient, + t.TempDir(), + "", + utils.NetAddr{}, + s.ProxyClient, + regular.SetUUID(proxyID), + regular.SetProxyMode("", revTunServer, s.ProxyClient), + regular.SetEmitter(s.ProxyClient), + regular.SetNamespace(apidefaults.Namespace), + regular.SetBPF(&bpf.NOP{}), + regular.SetRestrictedSessionManager(&restricted.NOP{}), + regular.SetClock(s.Clock), + regular.SetLockWatcher(proxyLockWatcher), + regular.SetNodeWatcher(proxyNodeWatcher), + ) + require.NoError(t, err) + + // Expired sessions are purged immediately + var sessionLingeringThreshold time.Duration + fs, err := NewDebugFileSystem(cfg.AssetDir) + require.NoError(t, err) + handler, err := NewHandler(Config{ + Proxy: revTunServer, + AuthServers: utils.FromAddr(s.Server.TLS.Addr()), + DomainName: s.Server.ClusterName(), + ProxyClient: s.ProxyClient, + CipherSuites: utils.DefaultCipherSuites(), + AccessPoint: s.ProxyClient, + Context: s.Ctx, + HostUUID: proxyID, + Emitter: s.ProxyClient, + StaticFS: fs, + CachedSessionLingeringThreshold: &sessionLingeringThreshold, + ProxySettings: &mockProxySettings{}, + PluginRegistry: cfg.PluginRegistry, + }, SetSessionStreamPollPeriod(200*time.Millisecond), SetClock(s.Clock)) + require.NoError(t, err) + + s.WebServer = httptest.NewUnstartedServer(handler) + s.WebServer.StartTLS() + err = s.Proxy.Start() + require.NoError(t, err) + + // Wait for proxy to fully register before starting the test. + for start := time.Now(); ; { + proxies, err := s.ProxyClient.GetProxies() + require.NoError(t, err) + if len(proxies) != 0 { + break + } + if time.Since(start) > 5*time.Second { + t.Fatal("proxy didn't register within 5s after startup") + } + } + + proxyAddr := utils.MustParseAddr(s.Proxy.Addr()) + + addr := utils.MustParseAddr(s.WebServer.Listener.Addr().String()) + handler.handler.cfg.ProxyWebAddr = *addr + handler.handler.cfg.ProxySSHAddr = *proxyAddr + _, sshPort, err := net.SplitHostPort(proxyAddr.String()) + require.NoError(t, err) + handler.handler.sshPort = sshPort + + t.Cleanup(func() { + // In particular close the lock watchers by canceling the context. + s.Cancel() + + s.WebServer.Close() + + var errors []error + if err := s.ProxyTunnel.Close(); err != nil { + errors = append(errors, err) + } + if err := s.Node.Close(); err != nil { + errors = append(errors, err) + } + s.WebServer.Close() + if err := s.Proxy.Close(); err != nil { + errors = append(errors, err) + } + if err := s.Server.Shutdown(context.Background()); err != nil { + errors = append(errors, err) + } + require.Empty(t, errors) + }) + + return s +} + +func noCache(clt auth.ClientI, cacheName []string) (auth.RemoteProxyAccessPoint, error) { + return clt, nil +} + +func (r *authPack) renewSession(ctx context.Context, t *testing.T) *roundtrip.Response { + resp, err := r.clt.PostJSON(ctx, r.clt.Endpoint("webapi", "sessions", "renew"), nil) + require.NoError(t, err) + return resp +} + +func (r *authPack) validateAPI(ctx context.Context, t *testing.T) { + _, err := r.clt.Get(ctx, r.clt.Endpoint("webapi", "sites"), url.Values{}) + require.NoError(t, err) +} + +type authPack struct { + otpSecret string + user string + login string + password string + session *CreateSessionResponse + clt *client.WebClient + cookies []*http.Cookie +} + +// authPack returns new authenticated package consisting of created valid +// user, otp token, created web session and authenticated client. +func (s *TestWebSuite) authPack(t *testing.T, user string) *authPack { + login := s.User + pass := "abc123" + rawSecret := "def456" + otpSecret := base32.StdEncoding.EncodeToString([]byte(rawSecret)) + + ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{ + Type: constants.Local, + SecondFactor: constants.SecondFactorOTP, + }) + require.NoError(t, err) + err = s.Server.Auth().SetAuthPreference(s.Ctx, ap) + require.NoError(t, err) + + s.createUser(t, user, login, pass, otpSecret) + + // create a valid otp token + validToken, err := totp.GenerateCode(otpSecret, s.Clock.Now()) + require.NoError(t, err) + + clt := s.client() + req := CreateSessionReq{ + User: user, + Pass: pass, + SecondFactorToken: validToken, + } + + csrfToken := "2ebcb768d0090ea4368e42880c970b61865c326172a4a2343b645cf5d7f20992" + re, err := s.login(clt, csrfToken, csrfToken, req) + require.NoError(t, err) + + var rawSess *CreateSessionResponse + require.NoError(t, json.Unmarshal(re.Bytes(), &rawSess)) + + sess, err := rawSess.response() + require.NoError(t, err) + + jar, err := cookiejar.New(nil) + require.NoError(t, err) + + clt = s.client(roundtrip.BearerAuth(sess.Token), roundtrip.CookieJar(jar)) + jar.SetCookies(s.url(), re.Cookies()) + + return &authPack{ + otpSecret: otpSecret, + user: user, + login: login, + session: sess, + clt: clt, + cookies: re.Cookies(), + } +} + +func (s *TestWebSuite) createUser(t *testing.T, user string, login string, pass string, otpSecret string) { + teleUser, err := types.NewUser(user) + require.NoError(t, err) + role := services.RoleForUser(teleUser) + role.SetLogins(types.Allow, []string{login}) + options := role.GetOptions() + options.ForwardAgent = types.NewBool(true) + role.SetOptions(options) + err = s.Server.Auth().UpsertRole(s.Ctx, role) + require.NoError(t, err) + teleUser.AddRole(role.GetName()) + + teleUser.SetCreatedBy(types.CreatedBy{ + User: types.UserRef{Name: "some-auth-user"}, + }) + err = s.Server.Auth().CreateUser(s.Ctx, teleUser) + require.NoError(t, err) + + err = s.Server.Auth().UpsertPassword(user, []byte(pass)) + require.NoError(t, err) + + if otpSecret != "" { + dev, err := services.NewTOTPDevice("otp", otpSecret, s.Clock.Now()) + require.NoError(t, err) + err = s.Server.Auth().UpsertMFADevice(context.Background(), user, dev) + require.NoError(t, err) + } +} + +func (s *TestWebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOpt) (*websocket.Conn, error) { + req := TerminalRequest{ + Server: s.SrvID, + Login: pack.login, + Term: session.TerminalParams{ + W: 100, + H: 100, + }, + SessionID: session.NewID(), + } + for _, opt := range opts { + opt(&req) + } + + u := url.URL{ + Host: s.url().Host, + Scheme: client.WSS, + Path: fmt.Sprintf("/v1/webapi/sites/%v/connect", currentSiteShortcut), + } + data, err := json.Marshal(req) + if err != nil { + return nil, err + } + + q := u.Query() + q.Set("params", string(data)) + q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) + u.RawQuery = q.Encode() + + dialer := websocket.Dialer{} + dialer.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + header := http.Header{} + header.Add("Origin", "http://localhost") + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } + + ws, resp, err := dialer.Dial(u.String(), header) + if err != nil { + return nil, trace.Wrap(err) + } + + require.NoError(t, resp.Body.Close()) + return ws, nil +} + +func (s *TestWebSuite) waitForRawEvent(ws *websocket.Conn, timeout time.Duration) error { + timeoutContext, timeoutCancel := context.WithTimeout(s.Ctx, timeout) + defer timeoutCancel() + + done := make(chan error, 1) + + go func() { + for { + ty, raw, err := ws.ReadMessage() + if err != nil { + done <- trace.Wrap(err) + return + } + + if ty != websocket.BinaryMessage { + done <- trace.BadParameter("expected binary message, got %v", ty) + return + } + + var envelope Envelope + err = proto.Unmarshal(raw, &envelope) + if err != nil { + done <- trace.Wrap(err) + return + } + + if envelope.GetType() == defaults.WebsocketRaw { + done <- nil + return + } + } + }() + + for { + select { + case <-timeoutContext.Done(): + return trace.BadParameter("timeout waiting for raw event") + case err := <-done: + return trace.Wrap(err) + } + } +} + +func (s *TestWebSuite) waitForResizeEvent(ws *websocket.Conn, timeout time.Duration) error { + timeoutContext, timeoutCancel := context.WithTimeout(s.Ctx, timeout) + defer timeoutCancel() + + done := make(chan error, 1) + + go func() { + for { + ty, raw, err := ws.ReadMessage() + if err != nil { + done <- trace.Wrap(err) + return + } + + if ty != websocket.BinaryMessage { + done <- trace.BadParameter("expected binary message, got %v", ty) + return + } + + var envelope Envelope + err = proto.Unmarshal(raw, &envelope) + if err != nil { + done <- trace.Wrap(err) + return + } + + if envelope.GetType() != defaults.WebsocketAudit { + continue + } + + var e events.EventFields + err = json.Unmarshal([]byte(envelope.GetPayload()), &e) + if err != nil { + done <- trace.Wrap(err) + return + } + + if e.GetType() == events.ResizeEvent { + done <- nil + return + } + } + }() + + for { + select { + case <-timeoutContext.Done(): + return trace.BadParameter("timeout waiting for resize event") + case err := <-done: + return trace.Wrap(err) + } + } +} + +func (s *TestWebSuite) listenForResizeEvent(ws *websocket.Conn) chan struct{} { + ch := make(chan struct{}) + + go func() { + for { + ty, raw, err := ws.ReadMessage() + if err != nil { + close(ch) + return + } + + if ty != websocket.BinaryMessage { + close(ch) + return + } + + var envelope Envelope + err = proto.Unmarshal(raw, &envelope) + if err != nil { + close(ch) + return + } + + if envelope.GetType() != defaults.WebsocketAudit { + continue + } + + var e events.EventFields + err = json.Unmarshal([]byte(envelope.GetPayload()), &e) + if err != nil { + close(ch) + return + } + + if e.GetType() == events.ResizeEvent { + ch <- struct{}{} + return + } + } + }() + + return ch +} + +func (s *TestWebSuite) ClientNoRedirects(opts ...roundtrip.ClientParam) *client.WebClient { + hclient := client.NewInsecureWebClient() + hclient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + opts = append(opts, roundtrip.HTTPClient(hclient)) + wc, err := client.NewWebClient(s.url().String(), opts...) + if err != nil { + panic(err) + } + return wc +} + +func (s *TestWebSuite) client(opts ...roundtrip.ClientParam) *client.WebClient { + opts = append(opts, roundtrip.HTTPClient(client.NewInsecureWebClient())) + wc, err := client.NewWebClient(s.url().String(), opts...) + if err != nil { + panic(err) + } + return wc +} + +func (s *TestWebSuite) login(clt *client.WebClient, cookieToken string, reqToken string, reqData interface{}) (*roundtrip.Response, error) { + return httplib.ConvertResponse(clt.RoundTrip(func() (*http.Response, error) { + data, err := json.Marshal(reqData) + if err != nil { + return nil, err + } + req, err := http.NewRequest("POST", clt.Endpoint("webapi", "sessions"), bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + AddCSRFCookieToReq(req, cookieToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set(csrf.HeaderName, reqToken) + return clt.HTTPClient().Do(req) + })) +} + +func (s *TestWebSuite) url() *url.URL { + u, err := url.Parse("https://" + s.WebServer.Listener.Addr().String()) + if err != nil { + panic(err) + } + return u +} + +func (r CreateSessionResponse) response() (*CreateSessionResponse, error) { + return &CreateSessionResponse{TokenType: r.TokenType, Token: r.Token, TokenExpiresIn: r.TokenExpiresIn, SessionInactiveTimeoutMS: r.SessionInactiveTimeoutMS}, nil +} + +type mockProxySettings struct{} + +func (mock *mockProxySettings) GetProxySettings(ctx context.Context) (*webclient.ProxySettings, error) { + return &webclient.ProxySettings{}, nil +} + +type terminalOpt func(t *TerminalRequest) + +func withSessionID(sid session.ID) terminalOpt { + return func(t *TerminalRequest) { t.SessionID = sid } +} + +func withKeepaliveInterval(d time.Duration) terminalOpt { + return func(t *TerminalRequest) { t.KeepAliveInterval = d } +} + +func AddCSRFCookieToReq(req *http.Request, token string) { + cookie := &http.Cookie{ + Name: csrf.CookieName, + Value: token, + } + + req.AddCookie(cookie) +}