Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/proto/teleport/legacy/types/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2800,6 +2800,9 @@ message WebSessionSpecV2 {
(gogoproto.jsontag) = "idle_timeout",
(gogoproto.casttype) = "Duration"
];
// ConsumedAccessRequestID is the ID of the access request from which additional roles to assume
// were obtained.
string ConsumedAccessRequestID = 10 [(gogoproto.jsontag) = "consumed_access_request_id,omitempty"];
}

// WebSessionFilter encodes cache watch parameters for filtering web sessions.
Expand Down
14 changes: 14 additions & 0 deletions api/types/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ type WebSession interface {
WithoutSecrets() WebSession
// String returns string representation of the session.
String() string
// SetConsumedAccessRequestID sets the ID of the access request from which additional roles to assume were obtained.
SetConsumedAccessRequestID(string)
// GetConsumedAccessRequestID returns the ID of the access request from which additional roles to assume were obtained.
GetConsumedAccessRequestID() string
}

// NewWebSession returns new instance of the web session based on the V2 spec
Expand Down Expand Up @@ -172,6 +176,16 @@ func (ws *WebSessionV2) WithoutSecrets() WebSession {
return ws
}

// SetConsumedAccessRequestID sets the ID of the access request from which additional roles to assume were obtained.
func (ws *WebSessionV2) SetConsumedAccessRequestID(requestID string) {
ws.Spec.ConsumedAccessRequestID = requestID
}

// GetConsumedAccessRequestID returns the ID of the access request from which additional roles to assume were obtained.
func (ws *WebSessionV2) GetConsumedAccessRequestID() string {
return ws.Spec.ConsumedAccessRequestID
}

// setStaticFields sets static resource header and metadata fields.
func (ws *WebSessionV2) setStaticFields() {
ws.Version = V2
Expand Down
1,868 changes: 958 additions & 910 deletions api/types/types.pb.go

Large diffs are not rendered by default.

9 changes: 2 additions & 7 deletions lib/auth/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,24 +516,19 @@ func (s *APIServer) createWebSession(auth ClientI, w http.ResponseWriter, r *htt
return nil, trace.Wrap(err)
}

// DELETE IN 8.0: proxy v5 sends request with no user field.
// And since proxy v6, request will come with user field set, so grabbing user
// by param is not required.
if req.User == "" {
req.User = p.ByName("user")
}

if req.PrevSessionID != "" {
sess, err := auth.ExtendWebSession(r.Context(), req)
if err != nil {
return nil, trace.Wrap(err)
}
return sess, nil
}

sess, err := auth.CreateWebSession(r.Context(), req.User)
if err != nil {
return nil, trace.Wrap(err)
}

return rawMessage(services.MarshalWebSession(sess, services.WithVersion(version)))
}

Expand Down
2 changes: 2 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -1937,6 +1937,8 @@ func (a *Server) ExtendWebSession(ctx context.Context, req WebSessionReq, identi
// Keep preserving the login time.
sess.SetLoginTime(prevSession.GetLoginTime())

sess.SetConsumedAccessRequestID(req.AccessRequestID)

if err := a.upsertWebSession(ctx, req.User, sess); err != nil {
return nil, trace.Wrap(err)
}
Expand Down
2 changes: 2 additions & 0 deletions lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,8 @@ func (h *Handler) getUserContext(w http.ResponseWriter, r *http.Request, p httpr
return nil, trace.Wrap(err)
}

userContext.ConsumedAccessRequestID = c.session.GetConsumedAccessRequestID()

return userContext, nil
}

Expand Down
62 changes: 62 additions & 0 deletions lib/web/apiserver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4720,3 +4720,65 @@ type mockProxySettings struct{}
func (mock *mockProxySettings) GetProxySettings(ctx context.Context) (*webclient.ProxySettings, error) {
return &webclient.ProxySettings{}, nil
}

// TestUserContextWithAccessRequest checks that the userContext includes the ID of the
// access request after it has been consumed and the web session has been renewed.
func TestUserContextWithAccessRequest(t *testing.T) {
t.Parallel()
env := newWebPack(t, 1)
proxy := env.proxies[0]
ctx := context.Background()

// Set user and role names.
username := "user"
baseRoleName := "role"
requestableRolename := "requestable-role"

// Create user's base role with the ability to request the requestable role.
baseRole, err := types.NewRole(baseRoleName, types.RoleSpecV5{
Allow: types.RoleConditions{
Request: &types.AccessRequestConditions{
Roles: []string{requestableRolename},
},
},
})
require.NoError(t, err)

// Create user with the base role.
pack := proxy.authPack(t, username, []types.Role{baseRole})

// Create the requestable role.
requestableRole, err := types.NewRole(requestableRolename, types.RoleSpecV5{})
require.NoError(t, err)
err = env.server.Auth().UpsertRole(ctx, requestableRole)
require.NoError(t, err)

// Create and approve an access request for the requestable role.
accessReq, err := services.NewAccessRequest(username, requestableRolename)
require.NoError(t, err)
accessReq.SetState(types.RequestState_APPROVED)
err = env.server.Auth().CreateAccessRequest(ctx, accessReq)
require.NoError(t, err)

// Get the ID of the created and approved access request.
accessRequestID := accessReq.GetMetadata().Name

// Make a request to renew the session with the ID of the access request.
_, err = pack.clt.PostJSON(ctx, pack.clt.Endpoint("webapi", "sessions", "renew"), renewSessionRequest{
AccessRequestID: accessRequestID,
})
require.NoError(t, err)

// Make a request to fetch the userContext.
endpoint := pack.clt.Endpoint("webapi", "sites", env.server.ClusterName(), "context")
response, err := pack.clt.Get(context.Background(), endpoint, url.Values{})
require.NoError(t, err)

// Process the JSON response of the request.
var userContext ui.UserContext
err = json.Unmarshal(response.Bytes(), &userContext)
require.NoError(t, err)

// Verify that the userContext returned contains the correct Access Request ID.
require.Equal(t, accessRequestID, userContext.ConsumedAccessRequestID)
}
1 change: 1 addition & 0 deletions lib/web/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ func (c *SessionContext) extendWebSession(ctx context.Context, accessRequestID s
if err != nil {
return nil, trace.Wrap(err)
}

return session, nil
}

Expand Down
3 changes: 3 additions & 0 deletions lib/web/ui/usercontext.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ type UserContext struct {
AccessStrategy accessStrategy `json:"accessStrategy"`
// AccessCapabilities defines allowable access request rules defined in a user's roles.
AccessCapabilities AccessCapabilities `json:"accessCapabilities"`
// ConsumedAccessRequestID is the request ID of the access request from which the assumed role was
// obtained
ConsumedAccessRequestID string `json:"accessRequestId,omitempty"`
}

func getWindowsDesktopLogins(roleSet services.RoleSet) []string {
Expand Down