Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions api/types/oidc_external.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
/*
Copyright 2022 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package types

import (
"encoding/json"
"time"

"github.com/gravitational/trace"
)

// OIDCClaims is a redefinition of jose.Claims with additional methods, required for serialization to/from protobuf.
// With those we can reference it with an option like so: `(gogoproto.customtype) = "OIDCClaims"`
type OIDCClaims map[string]interface{}

// Size returns size of the object when marshaled
func (a *OIDCClaims) Size() int {
bytes, err := json.Marshal(a)
if err != nil {
return 0
}
return len(bytes)
}

// Unmarshal the object from provided buffer.
func (a *OIDCClaims) Unmarshal(bytes []byte) error {
Comment thread
Tener marked this conversation as resolved.
return trace.Wrap(json.Unmarshal(bytes, a))
}

// MarshalTo marshals the object to sized buffer
func (a *OIDCClaims) MarshalTo(bytes []byte) (int, error) {
out, err := json.Marshal(a)
if err != nil {
return 0, trace.Wrap(err)
}

if len(out) > cap(bytes) {
return 0, trace.BadParameter("capacity too low: %v, need %v", cap(bytes), len(out))
}

copy(bytes, out)

return len(out), nil
}

// OIDCIdentity is a redefinition of oidc.Identity with additional methods, required for serialization to/from protobuf.
// With those we can reference it with an option like so: `(gogoproto.customtype) = "OIDCIdentity"`
type OIDCIdentity struct {
// ID is populated from "subject" claim.
ID string
// Name of user. Empty in current version of library.
Name string
// Email is populated from "email" claim.
Email string
// ExpiresAt populated from "exp" claim, represents expiry time.
ExpiresAt time.Time
}

// Size returns size of the object when marshaled
func (a *OIDCIdentity) Size() int {
bytes, err := json.Marshal(a)
if err != nil {
return 0
}
return len(bytes)
}

// Unmarshal the object from provided buffer.
func (a *OIDCIdentity) Unmarshal(bytes []byte) error {
return trace.Wrap(json.Unmarshal(bytes, a))
}

// MarshalTo marshals the object to sized buffer
func (a *OIDCIdentity) MarshalTo(bytes []byte) (int, error) {
out, err := json.Marshal(a)
if err != nil {
return 0, trace.Wrap(err)
}

if len(out) > cap(bytes) {
return 0, trace.BadParameter("capacity too low: %v, need %v", cap(bytes), len(out))
}

copy(bytes, out)

return len(out), nil
}
57 changes: 57 additions & 0 deletions api/types/oidc_external_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright 2022 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package types

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestOIDCClaimsRoundTrip(t *testing.T) {
tests := []struct {
name string
src OIDCClaims
}{
{
name: "empty",
src: OIDCClaims{},
},
{
name: "full",
src: OIDCClaims(map[string]interface{}{
"email_verified": true,
"groups": []interface{}{"everyone", "idp-admin", "idp-dev"},
"email": "superuser@example.com",
"sub": "00001234abcd",
"exp": 1652091713.0,
}),
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
buf := make([]byte, tt.src.Size())
count, err := tt.src.MarshalTo(buf)
require.NoError(t, err)
require.Equal(t, tt.src.Size(), count)

dst := &OIDCClaims{}
err = dst.Unmarshal(buf)
require.NoError(t, err)
require.Equal(t, &tt.src, dst)
})
}
}
1,992 changes: 1,164 additions & 828 deletions api/types/types.pb.go

Large diffs are not rendered by default.

55 changes: 46 additions & 9 deletions api/types/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -2740,41 +2740,78 @@ message SSODiagnosticInfo {
// CreateUserParams represents the user creation parameters as called during SSO login flow.
CreateUserParams CreateUserParams = 4 [ (gogoproto.jsontag) = "create_user_params,omitempty" ];

// SAMLAttributesToRoles represents mapping from attributes to roles, as used during SAML login
// flow.
repeated AttributeMapping SAMLAttributesToRoles = 5
[ (gogoproto.nullable) = false, (gogoproto.jsontag) = "attributes_to_roles,omitempty" ];
// SAMLAttributesToRoles represents mapping from attributes to roles, as used during SAML SSO
// login flow.
repeated AttributeMapping SAMLAttributesToRoles = 10 [
(gogoproto.nullable) = false,
(gogoproto.jsontag) = "saml_attributes_to_roles,omitempty"
];

// SAMLAttributesToRolesWarnings contains warnings produced during the process of mapping the
// SAML attributes to roles.
SSOWarnings SAMLAttributesToRolesWarnings = 6
SSOWarnings SAMLAttributesToRolesWarnings = 11
[ (gogoproto.jsontag) = "saml_attributes_to_roles_warnings,omitempty" ];

// SAMLAttributeStatements represents SAML attribute statements.
wrappers.LabelValues SAMLAttributeStatements = 7 [
wrappers.LabelValues SAMLAttributeStatements = 12 [
(gogoproto.nullable) = false,
(gogoproto.jsontag) = "saml_attribute_statements,omitempty",
(gogoproto.customtype) = "github.com/gravitational/teleport/api/types/wrappers.Traits"
];

// SAMLAssertionInfo represents raw SAML assertion info as returned by IdP during SAML flow.
bytes SAMLAssertionInfo = 8 [
bytes SAMLAssertionInfo = 13 [
(gogoproto.jsontag) = "saml_assertion_info,omitempty",
(gogoproto.customtype) = "AssertionInfo"
];

// SAMLTraitsFromAssertions represents traits translated from SAML assertions.
wrappers.LabelValues SAMLTraitsFromAssertions = 9 [
wrappers.LabelValues SAMLTraitsFromAssertions = 14 [
(gogoproto.nullable) = false,
(gogoproto.jsontag) = "saml_traits_from_assertions,omitempty",
(gogoproto.customtype) = "github.com/gravitational/teleport/api/types/wrappers.Traits"
];

// SAMLConnectorTraitMapping represents connector-specific trait mapping.
repeated TraitMapping SAMLConnectorTraitMapping = 10 [
repeated TraitMapping SAMLConnectorTraitMapping = 15 [
Comment thread
Tener marked this conversation as resolved.
(gogoproto.nullable) = false,
(gogoproto.jsontag) = "saml_connector_trait_mapping,omitempty"
];

// OIDCClaimsToRoles specifies a mapping from claims (traits) to teleport roles.
repeated ClaimMapping OIDCClaimsToRoles = 20
[ (gogoproto.nullable) = false, (gogoproto.jsontag) = "oidc_claims_to_roles,omitempty" ];

// OIDCClaimsToRolesWarnings contains warnings produced during the process of mapping the
// OIDC claims to roles.
SSOWarnings OIDCClaimsToRolesWarnings = 21
[ (gogoproto.jsontag) = "oidc_claims_to_roles_warnings,omitempty" ];

// OIDCClaims represents OIDC claims.
bytes OIDCClaims = 22 [
(gogoproto.nullable) = false,
(gogoproto.jsontag) = "oidc_claims,omitempty",
(gogoproto.customtype) = "OIDCClaims"
];

// OIDCIdentity represents mapped OIDC Identity.
bytes OIDCIdentity = 23 [
(gogoproto.jsontag) = "oidc_identity,omitempty",
(gogoproto.customtype) = "OIDCIdentity"
];

// OIDCTraitsFromClaims represents traits translated from OIDC claims.
wrappers.LabelValues OIDCTraitsFromClaims = 24 [
(gogoproto.nullable) = false,
(gogoproto.jsontag) = "oidc_traits_from_claims,omitempty",
(gogoproto.customtype) = "github.com/gravitational/teleport/api/types/wrappers.Traits"
];

// OIDCConnectorTraitMapping represents connector-specific trait mapping.
repeated TraitMapping OIDCConnectorTraitMapping = 25 [
(gogoproto.nullable) = false,
Comment thread
Tener marked this conversation as resolved.
(gogoproto.jsontag) = "oidc_connector_trait_mapping,omitempty"
];
}

// TeamMapping represents a single team membership mapping.
Expand Down
3 changes: 3 additions & 0 deletions constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,9 @@ const (
CertExtensionGeneration = "generation"
)

// Note: when adding new providers to this list, consider updating the help message for --provider flag
// for `tctl sso configure oidc` and `tctl sso configure saml` commands
// as well as docs at https://goteleport.com/docs/enterprise/sso/#provider-specific-workarounds
const (
// NetIQ is an identity provider.
NetIQ = "netiq"
Expand Down
11 changes: 10 additions & 1 deletion lib/auth/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ func NewAPIServer(config *APIConfig) (http.Handler, error) {
srv.DELETE("/:version/oidc/connectors/:id", srv.withAuth(srv.deleteOIDCConnector))
srv.POST("/:version/oidc/requests/create", srv.withAuth(srv.createOIDCAuthRequest))
srv.POST("/:version/oidc/requests/validate", srv.withAuth(srv.validateOIDCAuthCallback))
srv.GET("/:version/oidc/requests/get/:id", srv.withAuth(srv.getOIDCAuthRequest))

// SAML handlers
srv.POST("/:version/saml/connectors", srv.withAuth(srv.createSAMLConnector))
Expand Down Expand Up @@ -1243,6 +1244,14 @@ func (s *APIServer) createOIDCAuthRequest(auth ClientI, w http.ResponseWriter, r
return response, nil
}

func (s *APIServer) getOIDCAuthRequest(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) {
request, err := auth.GetOIDCAuthRequest(r.Context(), p.ByName("id"))
if err != nil {
return nil, trace.Wrap(err)
}
return request, nil
}

type validateOIDCAuthCallbackReq struct {
Query url.Values `json:"query"`
}
Expand Down Expand Up @@ -1272,7 +1281,7 @@ func (s *APIServer) validateOIDCAuthCallback(auth ClientI, w http.ResponseWriter
if err := httplib.ReadJSON(r, &req); err != nil {
return nil, trace.Wrap(err)
}
response, err := auth.ValidateOIDCAuthCallback(req.Query)
response, err := auth.ValidateOIDCAuthCallback(r.Context(), req.Query)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
7 changes: 6 additions & 1 deletion lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"sync"
"time"

"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
"github.com/google/uuid"
Expand Down Expand Up @@ -200,7 +201,8 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
WindowsDesktops: cfg.WindowsDesktops,
SessionTrackerService: cfg.SessionTrackerService,
},
keyStore: keyStore,
keyStore: keyStore,
getClaimsFun: getClaims,
}
for _, o := range opts {
o(&as)
Expand Down Expand Up @@ -360,6 +362,9 @@ type Server struct {

// lockWatcher is a lock watcher, used to verify cert generation requests.
lockWatcher *services.LockWatcher

// getClaimsFun is used in tests for overriding the implementation of getClaims method used in OIDC.
getClaimsFun func(closeCtx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error)
}

// SetCache sets cache used by auth server
Expand Down
8 changes: 7 additions & 1 deletion lib/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,13 @@ func TestOIDCConnectorCRUDEventsEmitted(t *testing.T) {

ctx := context.Background()
// test oidc create event
oidc, err := types.NewOIDCConnector("test", types.OIDCConnectorSpecV3{ClientID: "a"})
oidc, err := types.NewOIDCConnector("test", types.OIDCConnectorSpecV3{ClientID: "a", ClaimsToRoles: []types.ClaimMapping{
{
Claim: "dummy",
Value: "dummy",
Roles: []string{"dummy"},
},
}})
require.NoError(t, err)
err = s.a.UpsertOIDCConnector(ctx, oidc)
require.NoError(t, err)
Expand Down
23 changes: 19 additions & 4 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -2428,19 +2428,34 @@ func (a *ServerWithRoles) CreateOIDCAuthRequest(req services.OIDCAuthRequest) (*
return nil, trace.Wrap(err)
}

// require additional permissions for executing SSO test flow.
if req.SSOTestFlow {
if err := a.authConnectorAction(apidefaults.Namespace, types.KindOIDC, types.VerbCreate); err != nil {
return nil, trace.Wrap(err)
}
}

oidcReq, err := a.authServer.CreateOIDCAuthRequest(req)
if err != nil {
// TODO(Tener): Update `testFlow` flag once OIDC SSO starts supporting test flows.
emitSSOLoginFailureEvent(a.authServer.closeCtx, a.authServer.emitter, events.LoginMethodOIDC, err, false)
emitSSOLoginFailureEvent(a.authServer.closeCtx, a.authServer.emitter, events.LoginMethodOIDC, err, req.SSOTestFlow)
return nil, trace.Wrap(err)
}

return oidcReq, nil
}

func (a *ServerWithRoles) ValidateOIDCAuthCallback(q url.Values) (*OIDCAuthResponse, error) {
// GetOIDCAuthRequest returns OIDC auth request if found.
func (a *ServerWithRoles) GetOIDCAuthRequest(ctx context.Context, id string) (*services.OIDCAuthRequest, error) {
if err := a.action(apidefaults.Namespace, types.KindOIDCRequest, types.VerbRead); err != nil {
return nil, trace.Wrap(err)
}

return a.authServer.GetOIDCAuthRequest(ctx, id)
}

func (a *ServerWithRoles) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) {
// auth callback is it's own authz, no need to check extra permissions
return a.authServer.ValidateOIDCAuthCallback(q)
return a.authServer.ValidateOIDCAuthCallback(ctx, q)
}

func (a *ServerWithRoles) DeleteOIDCConnector(ctx context.Context, connectorID string) error {
Expand Down
Loading