diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 85301f637e889..f1b2f795c3149 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -2364,6 +2364,7 @@ func (h *Handler) githubCallback(w http.ResponseWriter, r *http.Request, p httpr CSRFToken: response.Req.CSRFToken, Username: response.Username, SessionName: response.Session.GetName(), + SessionExpiry: response.Session.Expiry(), ClientRedirectURL: response.Req.ClientRedirectURL, } @@ -2705,7 +2706,7 @@ func (h *Handler) createWebSession(w http.ResponseWriter, r *http.Request, p htt return nil, trace.AccessDenied("invalid credentials") } - if err := websession.SetCookie(w, req.User, webSession.GetName()); err != nil { + if err := websession.SetCookie(w, req.User, webSession.GetName(), webSession.Expiry()); err != nil { return nil, trace.Wrap(err) } @@ -2829,7 +2830,7 @@ func (h *Handler) renewWebSession(w http.ResponseWriter, r *http.Request, params if err != nil { return nil, trace.Wrap(err) } - if err := websession.SetCookie(w, newSession.GetUser(), newSession.GetName()); err != nil { + if err := websession.SetCookie(w, newSession.GetUser(), newSession.GetName(), newSession.Expiry()); err != nil { return nil, trace.Wrap(err) } @@ -2906,7 +2907,7 @@ func (h *Handler) changeUserAuthentication(w http.ResponseWriter, r *http.Reques return nil, trace.Wrap(err) } - if err := websession.SetCookie(w, sess.GetUser(), sess.GetName()); err != nil { + if err := websession.SetCookie(w, sess.GetUser(), sess.GetName(), sess.Expiry()); err != nil { return nil, trace.Wrap(err) } @@ -3111,7 +3112,7 @@ func (h *Handler) mfaLoginFinishSession(w http.ResponseWriter, r *http.Request, // Fetch user from session, user is empty for passwordless requests. user := session.GetUser() - if err := websession.SetCookie(w, user, session.GetName()); err != nil { + if err := websession.SetCookie(w, user, session.GetName(), session.Expiry()); err != nil { return nil, trace.Wrap(err) } @@ -5373,6 +5374,10 @@ type SSOCallbackResponse struct { // SessionName is the name of the session generated by auth server if // requested in the SSO request. SessionName string + // SessionExpiry is the expiration of the session. This is used + // to set the expiration time of the cookie. If no expiraton is set, + // the cookie with be a "session cookie", which is removed when the browser closes. + SessionExpiry time.Time // ClientRedirectURL is the URL to redirect back to on completion of // the SSO login process. ClientRedirectURL string @@ -5397,7 +5402,7 @@ func SSOSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, resp } } - if err := websession.SetCookie(w, response.Username, response.SessionName); err != nil { + if err := websession.SetCookie(w, response.Username, response.SessionName, response.SessionExpiry); err != nil { return trace.Wrap(err) } diff --git a/lib/web/session/cookie.go b/lib/web/session/cookie.go index df9bb4b538a6b..08800c1800216 100644 --- a/lib/web/session/cookie.go +++ b/lib/web/session/cookie.go @@ -22,6 +22,7 @@ import ( "encoding/hex" "encoding/json" "net/http" + "time" ) // Cookie stores information about active user and session @@ -56,7 +57,7 @@ func DecodeCookie(b string) (*Cookie, error) { // SetCookie encodes the provided user and session id via [EncodeCookie] // and then sets the [http.Cookie] of the provided [http.ResponseWriter]. -func SetCookie(w http.ResponseWriter, user, sid string) error { +func SetCookie(w http.ResponseWriter, user, sid string, expiry time.Time) error { d, err := EncodeCookie(user, sid) if err != nil { return err @@ -69,6 +70,11 @@ func SetCookie(w http.ResponseWriter, user, sid string) error { Secure: true, SameSite: http.SameSiteLaxMode, } + // if expiry is zero, we can skip MaxAge and treat as a session cookie. + // Otherwise, set maxage + if !expiry.IsZero() { + c.MaxAge = int(time.Until(expiry).Seconds()) + } http.SetCookie(w, c) return nil } diff --git a/lib/web/session/cookie_test.go b/lib/web/session/cookie_test.go index 8f7d685033cfe..fd4f74697ffa5 100644 --- a/lib/web/session/cookie_test.go +++ b/lib/web/session/cookie_test.go @@ -21,6 +21,7 @@ package session import ( "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -32,23 +33,59 @@ func TestCookies(t *testing.T) { ) expectedCookie := &Cookie{User: user, SID: sessionID} - encodedCookie, err := EncodeCookie(user, sessionID) - require.NoError(t, err) + t.Run("encode and decode", func(t *testing.T) { + encodedCookie, err := EncodeCookie(user, sessionID) + require.NoError(t, err) - cookie, err := DecodeCookie(encodedCookie) - require.NoError(t, err) - require.Equal(t, expectedCookie, cookie) + cookie, err := DecodeCookie(encodedCookie) + require.NoError(t, err) + require.Equal(t, expectedCookie, cookie) + }) - recorder := httptest.NewRecorder() - require.Empty(t, recorder.Header().Get("Set-Cookie")) + tests := []struct { + name string + expiry time.Time + expectClear bool + expectedCookie string + }{ + { + name: "valid expiry", + expiry: time.Now().Add(10 * time.Second), + expectClear: true, + expectedCookie: "__Host-session=7b2275736572223a226c6c616d61222c22736964223a223938373635227d; Path=/; Max-Age=9; HttpOnly; Secure; SameSite=Lax", + }, + { + name: "expired cert", + expiry: time.Now().Add(-10 * time.Second), + expectClear: false, + expectedCookie: "__Host-session=7b2275736572223a226c6c616d61222c22736964223a223938373635227d; Path=/; Max-Age=0; HttpOnly; Secure; SameSite=Lax", + }, + { + name: "zero time", + expiry: time.Time{}, + expectClear: false, + expectedCookie: "__Host-session=7b2275736572223a226c6c616d61222c22736964223a223938373635227d; Path=/; HttpOnly; Secure; SameSite=Lax", + }, + } - require.NoError(t, SetCookie(recorder, user, sessionID)) - ClearCookie(recorder) - setCookies := recorder.Header().Values("Set-Cookie") - require.Len(t, setCookies, 2) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + recorder := httptest.NewRecorder() + require.Empty(t, recorder.Header().Get("Set-Cookie")) - // SetCookie will store the encoded session in the cookie - require.Equal(t, "__Host-session=7b2275736572223a226c6c616d61222c22736964223a223938373635227d; Path=/; HttpOnly; Secure; SameSite=Lax", setCookies[0]) - // ClearCookie will add an entry with the cookie value cleared out - require.Equal(t, "__Host-session=; Path=/; HttpOnly; Secure; SameSite=Lax", setCookies[1]) + require.NoError(t, SetCookie(recorder, user, sessionID, tt.expiry)) + + if tt.expectClear { + ClearCookie(recorder) + setCookies := recorder.Header().Values("Set-Cookie") + require.Len(t, setCookies, 2) + require.Equal(t, tt.expectedCookie, setCookies[0]) + require.Equal(t, "__Host-session=; Path=/; HttpOnly; Secure; SameSite=Lax", setCookies[1]) + } else { + setCookies := recorder.Header().Values("Set-Cookie") + require.Len(t, setCookies, 1) + require.Equal(t, tt.expectedCookie, setCookies[0]) + } + }) + } }