Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2364,6 +2364,7 @@ func (h *Handler) githubCallback(w http.ResponseWriter, r *http.Request, p httpr
CSRFToken: response.Req.CSRFToken,
Username: response.Username,
SessionName: response.Session.GetName(),
SessionExpiry: response.Session.Expiry(),
ClientRedirectURL: response.Req.ClientRedirectURL,
}

Expand Down Expand Up @@ -2705,7 +2706,7 @@ func (h *Handler) createWebSession(w http.ResponseWriter, r *http.Request, p htt
return nil, trace.AccessDenied("invalid credentials")
}

if err := websession.SetCookie(w, req.User, webSession.GetName()); err != nil {
if err := websession.SetCookie(w, req.User, webSession.GetName(), webSession.Expiry()); err != nil {
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -2829,7 +2830,7 @@ func (h *Handler) renewWebSession(w http.ResponseWriter, r *http.Request, params
if err != nil {
return nil, trace.Wrap(err)
}
if err := websession.SetCookie(w, newSession.GetUser(), newSession.GetName()); err != nil {
if err := websession.SetCookie(w, newSession.GetUser(), newSession.GetName(), newSession.Expiry()); err != nil {
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -2906,7 +2907,7 @@ func (h *Handler) changeUserAuthentication(w http.ResponseWriter, r *http.Reques
return nil, trace.Wrap(err)
}

if err := websession.SetCookie(w, sess.GetUser(), sess.GetName()); err != nil {
if err := websession.SetCookie(w, sess.GetUser(), sess.GetName(), sess.Expiry()); err != nil {
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -3111,7 +3112,7 @@ func (h *Handler) mfaLoginFinishSession(w http.ResponseWriter, r *http.Request,

// Fetch user from session, user is empty for passwordless requests.
user := session.GetUser()
if err := websession.SetCookie(w, user, session.GetName()); err != nil {
if err := websession.SetCookie(w, user, session.GetName(), session.Expiry()); err != nil {
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -5373,6 +5374,10 @@ type SSOCallbackResponse struct {
// SessionName is the name of the session generated by auth server if
// requested in the SSO request.
SessionName string
// SessionExpiry is the expiration of the session. This is used
// to set the expiration time of the cookie. If no expiraton is set,
// the cookie with be a "session cookie", which is removed when the browser closes.
SessionExpiry time.Time
// ClientRedirectURL is the URL to redirect back to on completion of
// the SSO login process.
ClientRedirectURL string
Expand All @@ -5397,7 +5402,7 @@ func SSOSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, resp
}
}

if err := websession.SetCookie(w, response.Username, response.SessionName); err != nil {
if err := websession.SetCookie(w, response.Username, response.SessionName, response.SessionExpiry); err != nil {
return trace.Wrap(err)
}

Expand Down
8 changes: 7 additions & 1 deletion lib/web/session/cookie.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"encoding/hex"
"encoding/json"
"net/http"
"time"
)

// Cookie stores information about active user and session
Expand Down Expand Up @@ -56,7 +57,7 @@ func DecodeCookie(b string) (*Cookie, error) {

// SetCookie encodes the provided user and session id via [EncodeCookie]
// and then sets the [http.Cookie] of the provided [http.ResponseWriter].
func SetCookie(w http.ResponseWriter, user, sid string) error {
func SetCookie(w http.ResponseWriter, user, sid string, expiry time.Time) error {
d, err := EncodeCookie(user, sid)
if err != nil {
return err
Expand All @@ -69,6 +70,11 @@ func SetCookie(w http.ResponseWriter, user, sid string) error {
Secure: true,
SameSite: http.SameSiteLaxMode,
}
// if expiry is zero, we can skip MaxAge and treat as a session cookie.
// Otherwise, set maxage
if !expiry.IsZero() {
c.MaxAge = int(time.Until(expiry).Seconds())
}
http.SetCookie(w, c)
return nil
}
Expand Down
67 changes: 52 additions & 15 deletions lib/web/session/cookie_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package session
import (
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/require"
)
Expand All @@ -32,23 +33,59 @@ func TestCookies(t *testing.T) {
)
expectedCookie := &Cookie{User: user, SID: sessionID}

encodedCookie, err := EncodeCookie(user, sessionID)
require.NoError(t, err)
t.Run("encode and decode", func(t *testing.T) {
encodedCookie, err := EncodeCookie(user, sessionID)
require.NoError(t, err)

cookie, err := DecodeCookie(encodedCookie)
require.NoError(t, err)
require.Equal(t, expectedCookie, cookie)
cookie, err := DecodeCookie(encodedCookie)
require.NoError(t, err)
require.Equal(t, expectedCookie, cookie)
})

recorder := httptest.NewRecorder()
require.Empty(t, recorder.Header().Get("Set-Cookie"))
tests := []struct {
name string
expiry time.Time
expectClear bool
expectedCookie string
}{
{
name: "valid expiry",
expiry: time.Now().Add(10 * time.Second),
expectClear: true,
expectedCookie: "__Host-session=7b2275736572223a226c6c616d61222c22736964223a223938373635227d; Path=/; Max-Age=9; HttpOnly; Secure; SameSite=Lax",
},
{
name: "expired cert",
expiry: time.Now().Add(-10 * time.Second),
expectClear: false,
expectedCookie: "__Host-session=7b2275736572223a226c6c616d61222c22736964223a223938373635227d; Path=/; Max-Age=0; HttpOnly; Secure; SameSite=Lax",
},
{
name: "zero time",
expiry: time.Time{},
expectClear: false,
expectedCookie: "__Host-session=7b2275736572223a226c6c616d61222c22736964223a223938373635227d; Path=/; HttpOnly; Secure; SameSite=Lax",
},
}

require.NoError(t, SetCookie(recorder, user, sessionID))
ClearCookie(recorder)
setCookies := recorder.Header().Values("Set-Cookie")
require.Len(t, setCookies, 2)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
require.Empty(t, recorder.Header().Get("Set-Cookie"))

// SetCookie will store the encoded session in the cookie
require.Equal(t, "__Host-session=7b2275736572223a226c6c616d61222c22736964223a223938373635227d; Path=/; HttpOnly; Secure; SameSite=Lax", setCookies[0])
// ClearCookie will add an entry with the cookie value cleared out
require.Equal(t, "__Host-session=; Path=/; HttpOnly; Secure; SameSite=Lax", setCookies[1])
require.NoError(t, SetCookie(recorder, user, sessionID, tt.expiry))

if tt.expectClear {
ClearCookie(recorder)
setCookies := recorder.Header().Values("Set-Cookie")
require.Len(t, setCookies, 2)
require.Equal(t, tt.expectedCookie, setCookies[0])
require.Equal(t, "__Host-session=; Path=/; HttpOnly; Secure; SameSite=Lax", setCookies[1])
} else {
setCookies := recorder.Header().Values("Set-Cookie")
require.Len(t, setCookies, 1)
require.Equal(t, tt.expectedCookie, setCookies[0])
}
})
}
}
Loading