Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions api/types/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ type WebSession interface {
WithoutSecrets() WebSession
// String returns string representation of the session.
String() string
// SetConsumedAccessRequestID sets the ID of the access request from which additional roles to assume were obtained.
SetConsumedAccessRequestID(string)
// GetConsumedAccessRequestID returns the ID of the access request from which additional roles to assume were obtained.
GetConsumedAccessRequestID() string
}

// NewWebSession returns new instance of the web session based on the V2 spec
Expand Down Expand Up @@ -172,6 +176,16 @@ func (ws *WebSessionV2) WithoutSecrets() WebSession {
return ws
}

// SetConsumedAccessRequestID sets the ID of the access request from which additional roles to assume were obtained.
func (ws *WebSessionV2) SetConsumedAccessRequestID(requestID string) {
ws.Spec.ConsumedAccessRequestID = requestID
}

// GetConsumedAccessRequestID returns the ID of the access request from which additional roles to assume were obtained.
func (ws *WebSessionV2) GetConsumedAccessRequestID() string {
return ws.Spec.ConsumedAccessRequestID
}

// setStaticFields sets static resource header and metadata fields.
func (ws *WebSessionV2) setStaticFields() {
ws.Version = V2
Expand Down
1,556 changes: 802 additions & 754 deletions api/types/types.pb.go

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions api/types/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2331,6 +2331,10 @@ message WebSessionSpecV2 {
// IdleTimeout is the max time a user can be inactive in a session.
int64 IdleTimeout = 9
[ (gogoproto.jsontag) = "idle_timeout", (gogoproto.casttype) = "Duration" ];
// ConsumedAccessRequestID is the ID of the access request from which additional roles to assume
// were obtained.
string ConsumedAccessRequestID = 10
[ (gogoproto.jsontag) = "consumed_access_request_id,omitempty" ];
}

// WebSessionFilter encodes cache watch parameters for filtering web sessions.
Expand Down
9 changes: 2 additions & 7 deletions lib/auth/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -796,24 +796,19 @@ func (s *APIServer) createWebSession(auth ClientI, w http.ResponseWriter, r *htt
return nil, trace.Wrap(err)
}

// DELETE IN 8.0: proxy v5 sends request with no user field.
// And since proxy v6, request will come with user field set, so grabbing user
// by param is not required.
if req.User == "" {
req.User = p.ByName("user")
}

if req.PrevSessionID != "" {
sess, err := auth.ExtendWebSession(r.Context(), req)
if err != nil {
return nil, trace.Wrap(err)
}
return sess, nil
}

sess, err := auth.CreateWebSession(r.Context(), req.User)
if err != nil {
return nil, trace.Wrap(err)
}

return rawMessage(services.MarshalWebSession(sess, services.WithVersion(version)))
}

Expand Down
2 changes: 2 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -1977,6 +1977,8 @@ func (a *Server) ExtendWebSession(ctx context.Context, req WebSessionReq, identi
// Keep preserving the login time.
sess.SetLoginTime(prevSession.GetLoginTime())

sess.SetConsumedAccessRequestID(req.AccessRequestID)

if err := a.upsertWebSession(context.TODO(), req.User, sess); err != nil {
return nil, trace.Wrap(err)
}
Expand Down
2 changes: 2 additions & 0 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,8 @@ func (h *Handler) getUserContext(w http.ResponseWriter, r *http.Request, p httpr
return nil, trace.Wrap(err)
}

userContext.ConsumedAccessRequestID = c.session.GetConsumedAccessRequestID()

return userContext, nil
}

Expand Down
2 changes: 1 addition & 1 deletion lib/web/apiserver_login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ func configureClusterForMFA(t *testing.T, env *webPack, spec *types.AuthPreferen
// Create user.
const user = "llama"
const password = "password"
env.proxies[0].createUser(ctx, t, user, "root", "password", "" /* otpSecret */)
env.proxies[0].createUser(ctx, t, user, "root", "password", "" /* otpSecret */, nil /* roles */)

// Register device.
clt, err := env.server.NewClient(auth.TestUser(user))
Expand Down
132 changes: 103 additions & 29 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -745,8 +745,8 @@ func TestValidateBearerToken(t *testing.T) {
t.Parallel()
env := newWebPack(t, 1)
proxy := env.proxies[0]
pack1 := proxy.authPack(t, "user1")
pack2 := proxy.authPack(t, "user2")
pack1 := proxy.authPack(t, "user1", nil /* roles */)
pack2 := proxy.authPack(t, "user2", nil /* roles */)

// Swap pack1's session token with pack2's sessionToken
jar, err := cookiejar.New(nil)
Expand Down Expand Up @@ -833,7 +833,7 @@ func TestClusterNodesGet(t *testing.T) {
t.Parallel()
env := newWebPack(t, 1)
proxy := env.proxies[0]
pack := proxy.authPack(t, "test-user@example.com")
pack := proxy.authPack(t, "test-user@example.com", nil /* roles */)
clusterName := env.server.ClusterName()

endpoint := pack.clt.Endpoint("webapi", "sites", clusterName, "nodes")
Expand Down Expand Up @@ -1180,7 +1180,7 @@ func TestTerminalRequireSessionMfa(t *testing.T) {
ctx := context.Background()
env := newWebPack(t, 1)
proxy := env.proxies[0]
pack := proxy.authPack(t, "llama")
pack := proxy.authPack(t, "llama", nil /* roles */)

clt, err := env.server.NewClient(auth.TestUser("llama"))
require.NoError(t, err)
Expand Down Expand Up @@ -1427,7 +1427,7 @@ func TestDesktopAccessMFARequiresMfa(t *testing.T) {
ctx := context.Background()
env := newWebPack(t, 1)
proxy := env.proxies[0]
pack := proxy.authPack(t, "llama")
pack := proxy.authPack(t, "llama", nil /* roles */)

clt, err := env.server.NewClient(auth.TestUser("llama"))
require.NoError(t, err)
Expand Down Expand Up @@ -1705,7 +1705,7 @@ func TestCreateSession(t *testing.T) {
env := newWebPack(t, 1)
proxy := env.proxies[0]
user := "test-user@example.com"
pack := proxy.authPack(t, user)
pack := proxy.authPack(t, user, nil /* roles */)

// get site nodes
re, err := pack.clt.Get(context.Background(), pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "nodes"), url.Values{})
Expand Down Expand Up @@ -2383,7 +2383,7 @@ func TestTokenGeneration(t *testing.T) {
env := newWebPack(t, 1)

proxy := env.proxies[0]
pack := proxy.authPack(t, "test-user@example.com")
pack := proxy.authPack(t, "test-user@example.com", nil /* roles */)

endpoint := pack.clt.Endpoint("webapi", "token")
re, err := pack.clt.PostJSON(context.Background(), endpoint, types.ProvisionTokenSpecV2{
Expand Down Expand Up @@ -2422,7 +2422,7 @@ func TestClusterDatabasesGet(t *testing.T) {
env := newWebPack(t, 1)

proxy := env.proxies[0]
pack := proxy.authPack(t, "test-user@example.com")
pack := proxy.authPack(t, "test-user@example.com", nil /* roles */)

endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "databases")
re, err := pack.clt.Get(context.Background(), endpoint, url.Values{})
Expand Down Expand Up @@ -2474,7 +2474,7 @@ func TestClusterKubesGet(t *testing.T) {
env := newWebPack(t, 1)

proxy := env.proxies[0]
pack := proxy.authPack(t, "test-user@example.com")
pack := proxy.authPack(t, "test-user@example.com", nil /* roles */)

endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "kubernetes")
re, err := pack.clt.Get(context.Background(), endpoint, url.Values{})
Expand Down Expand Up @@ -2528,7 +2528,7 @@ func TestClusterAppsGet(t *testing.T) {
env := newWebPack(t, 1)

proxy := env.proxies[0]
pack := proxy.authPack(t, "test-user@example.com")
pack := proxy.authPack(t, "test-user@example.com", nil /* roles */)

type testResponse struct {
Items []ui.App `json:"items"`
Expand Down Expand Up @@ -2593,7 +2593,7 @@ func TestApplicationAccessDisabled(t *testing.T) {
env := newWebPack(t, 1)

proxy := env.proxies[0]
pack := proxy.authPack(t, "foo@example.com")
pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

// Register an application.
app, err := types.NewAppV3(types.Metadata{
Expand Down Expand Up @@ -2624,7 +2624,7 @@ func TestApplicationWebSessionsDeletedAfterLogout(t *testing.T) {
env := newWebPack(t, 1)

proxy := env.proxies[0]
pack := proxy.authPack(t, "foo@example.com")
pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

// Register multiple applications.
applications := []struct {
Expand Down Expand Up @@ -2682,7 +2682,7 @@ func TestCreatePrivilegeToken(t *testing.T) {
proxy := env.proxies[0]

// Create a user with second factor totp.
pack := proxy.authPack(t, "foo@example.com")
pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

// Get a totp code.
totpCode, err := totp.GenerateCode(pack.otpSecret, env.clock.Now().Add(30*time.Second))
Expand All @@ -2705,7 +2705,7 @@ func TestAddMFADevice(t *testing.T) {
ctx := context.Background()
env := newWebPack(t, 1)
proxy := env.proxies[0]
pack := proxy.authPack(t, "foo@example.com")
pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

// Enable second factor.
ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{
Expand Down Expand Up @@ -2832,7 +2832,7 @@ func TestDeleteMFA(t *testing.T) {
ctx := context.Background()
env := newWebPack(t, 1)
proxy := env.proxies[0]
pack := proxy.authPack(t, "foo@example.com")
pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

//setting up client manually because we need sanitizer off
jar, err := cookiejar.New(nil)
Expand Down Expand Up @@ -2878,7 +2878,7 @@ func TestGetMFADevicesWithAuth(t *testing.T) {
t.Parallel()
env := newWebPack(t, 1)
proxy := env.proxies[0]
pack := proxy.authPack(t, "foo@example.com")
pack := proxy.authPack(t, "foo@example.com", nil /* roles */)

endpoint := pack.clt.Endpoint("webapi", "mfa", "devices")
re, err := pack.clt.Get(context.Background(), endpoint, url.Values{})
Expand All @@ -2898,7 +2898,7 @@ func TestGetAndDeleteMFADevices_WithRecoveryApprovedToken(t *testing.T) {

// Create a user with a TOTP device.
username := "llama"
proxy.createUser(ctx, t, username, "root", "password", "some-otp-secret")
proxy.createUser(ctx, t, username, "root", "password", "some-otp-secret", nil /* roles */)

// Enable second factor.
ap, err := types.NewAuthPreference(types.AuthPreferenceSpecV2{
Expand Down Expand Up @@ -2953,7 +2953,7 @@ func TestCreateAuthenticateChallenge(t *testing.T) {
proxy := env.proxies[0]

// Create a user with a TOTP device, with second factor preference to OTP only.
authPack := proxy.authPack(t, "llama@example.com")
authPack := proxy.authPack(t, "llama@example.com", nil /* roles */)

// Authenticated client for private endpoints.
authnClt := authPack.clt
Expand Down Expand Up @@ -3250,7 +3250,7 @@ func TestNewSessionResponseWithRenewSession(t *testing.T) {
require.NoError(t, env.server.Auth().SetClusterNetworkingConfig(context.Background(), cfg))

proxy := env.proxies[0]
pack := proxy.authPack(t, "foo")
pack := proxy.authPack(t, "foo", nil /* roles */)

var ns *CreateSessionResponse
resp := pack.renewSession(context.Background(), t)
Expand All @@ -3273,7 +3273,7 @@ func TestWebSessionsRenewDoesNotBreakExistingTerminalSession(t *testing.T) {

proxy1, proxy2 := env.proxies[0], env.proxies[1]
// Connect to both proxies
pack1 := proxy1.authPack(t, "foo")
pack1 := proxy1.authPack(t, "foo", nil /* roles */)
pack2 := proxy2.authPackFromPack(t, pack1)

ws := proxy2.makeTerminal(t, pack2, session.NewID())
Expand Down Expand Up @@ -3310,7 +3310,7 @@ func TestWebSessionsRenewAllowsOldBearerTokenToLinger(t *testing.T) {
env := newWebPack(t, 1)

proxy := env.proxies[0]
pack := proxy.authPack(t, "foo")
pack := proxy.authPack(t, "foo", nil /* roles */)

delta := 30 * time.Second
// Advance the time before renewing the session.
Expand Down Expand Up @@ -4123,7 +4123,7 @@ type proxy struct {

// authPack returns new authenticated package consisting of created valid
// user, otp token, created web session and authenticated client.
func (r *proxy) authPack(t *testing.T, user string) *authPack {
func (r *proxy) authPack(t *testing.T, user string, roles []types.Role) *authPack {
ctx := context.Background()
const (
loginUser = "user"
Expand All @@ -4141,7 +4141,7 @@ func (r *proxy) authPack(t *testing.T, user string) *authPack {
err = r.auth.Auth().SetAuthPreference(ctx, ap)
require.NoError(t, err)

r.createUser(context.Background(), t, user, loginUser, pass, otpSecret)
r.createUser(context.Background(), t, user, loginUser, pass, otpSecret, roles)

// create a valid otp token
validToken, err := totp.GenerateCode(otpSecret, r.clock.Now())
Expand Down Expand Up @@ -4214,19 +4214,31 @@ func (r *proxy) authPackFromResponse(t *testing.T, httpResp *roundtrip.Response)
}
}

func (r *proxy) createUser(ctx context.Context, t *testing.T, user, login, pass, otpSecret string) {
teleUser, err := types.NewUser(user)
require.NoError(t, err)

func defaultRoleForNewUser(teleUser types.User, login string) types.Role {
role := services.RoleForUser(teleUser)
role.SetLogins(types.Allow, []string{login})
role.SetWindowsDesktopLabels(types.Allow, types.Labels{types.Wildcard: {types.Wildcard}})
options := role.GetOptions()
options.ForwardAgent = types.NewBool(true)
role.SetOptions(options)
err = r.auth.Auth().UpsertRole(ctx, role)
return role
}

func (r *proxy) createUser(ctx context.Context, t *testing.T, user, login, pass, otpSecret string, roles []types.Role) {
teleUser, err := types.NewUser(user)
require.NoError(t, err)

teleUser.AddRole(role.GetName())
if len(roles) == 0 {
roles = []types.Role{defaultRoleForNewUser(teleUser, login)}
}

for _, role := range roles {
err = r.auth.Auth().UpsertRole(ctx, role)
require.NoError(t, err)

teleUser.AddRole(role.GetName())
}

teleUser.SetCreatedBy(types.CreatedBy{
User: types.UserRef{Name: "some-auth-user"},
})
Expand Down Expand Up @@ -4362,3 +4374,65 @@ type mockProxySettings struct {
func (mock *mockProxySettings) GetProxySettings(ctx context.Context) (*webclient.ProxySettings, error) {
return &webclient.ProxySettings{}, nil
}

// TestUserContextWithAccessRequest checks that the userContext includes the ID of the
// access request after it has been consumed and the web session has been renewed.
func TestUserContextWithAccessRequest(t *testing.T) {
t.Parallel()
env := newWebPack(t, 1)
proxy := env.proxies[0]
ctx := context.Background()

// Set user and role names.
username := "user"
baseRoleName := "role"
requestableRolename := "requestable-role"

// Create user's base role with the ability to request the requestable role.
baseRole, err := types.NewRole(baseRoleName, types.RoleSpecV5{
Allow: types.RoleConditions{
Request: &types.AccessRequestConditions{
Roles: []string{requestableRolename},
},
},
})
require.NoError(t, err)

// Create user with the base role.
pack := proxy.authPack(t, username, []types.Role{baseRole})

// Create the requestable role.
requestableRole, err := types.NewRole(requestableRolename, types.RoleSpecV5{})
require.NoError(t, err)
err = env.server.Auth().UpsertRole(ctx, requestableRole)
require.NoError(t, err)

// Create and approve an access request for the requestable role.
accessReq, err := services.NewAccessRequest(username, requestableRolename)
require.NoError(t, err)
accessReq.SetState(types.RequestState_APPROVED)
err = env.server.Auth().CreateAccessRequest(ctx, accessReq)
require.NoError(t, err)

// Get the ID of the created and approved access request.
accessRequestID := accessReq.GetMetadata().Name

// Make a request to renew the session with the ID of the access request.
_, err = pack.clt.PostJSON(ctx, pack.clt.Endpoint("webapi", "sessions", "renew"), renewSessionRequest{
AccessRequestID: accessRequestID,
})
require.NoError(t, err)

// Make a request to fetch the userContext.
endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "context")
response, err := pack.clt.Get(context.Background(), endpoint, url.Values{})
require.NoError(t, err)

// Process the JSON response of the request.
var userContext ui.UserContext
err = json.Unmarshal(response.Bytes(), &userContext)
require.NoError(t, err)

// Verify that the userContext returned contains the correct Access Request ID.
require.Equal(t, accessRequestID, userContext.ConsumedAccessRequestID)
}
Loading