From 5df1605773097597e463976e7456c701db816563 Mon Sep 17 00:00:00 2001 From: Jackson Argo Date: Fri, 4 Aug 2023 11:34:35 -0400 Subject: [PATCH] add regex for oidc group matching Signed-off-by: Jackson Argo --- connector/oidc/oidc.go | 21 +++++++++++++++++++++ connector/oidc/oidc_test.go | 19 +++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 21129f2227..1495be4b0c 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -8,6 +8,7 @@ import ( "fmt" "net/http" "net/url" + "regexp" "strings" "time" @@ -93,6 +94,7 @@ type Config struct { // ClaimMutations holds all claim mutations options ClaimMutations struct { NewGroupFromClaims []NewGroupFromClaims `json:"newGroupFromClaims"` + FilterGroupClaims FilterGroupClaims `json:"filterGroupClaims"` } `json:"claimModifications"` } @@ -112,6 +114,12 @@ type NewGroupFromClaims struct { Prefix string `json:"prefix"` } +// FilterGroupClaims is a regex filter for to keep only the matching groups. +// This is useful when the groups list is too large to fit within an HTTP header. +type FilterGroupClaims struct { + GroupsFilter string `json:"groupsFilter"` +} + // Domains that don't support basic auth. golang.org/x/oauth2 has an internal // list, but it only matches specific URLs, not top level domains. var brokenAuthHeaderDomains = []string{ @@ -184,6 +192,14 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e c.PromptType = "consent" } + var groupsFilter *regexp.Regexp + if c.ClaimMutations.FilterGroupClaims.GroupsFilter != "" { + groupsFilter, err = regexp.Compile(c.ClaimMutations.FilterGroupClaims.GroupsFilter) + if err != nil { + logger.Warnf("ignoring invalid regex `%s`", c.ClaimMutations.FilterGroupClaims.GroupsFilter) + } + } + clientID := c.ClientID return &oidcConnector{ provider: provider, @@ -214,6 +230,7 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e emailKey: c.ClaimMapping.EmailKey, groupsKey: c.ClaimMapping.GroupsKey, newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, + groupsFilter: groupsFilter, }, nil } @@ -243,6 +260,7 @@ type oidcConnector struct { emailKey string groupsKey string newGroupFromClaims []NewGroupFromClaims + groupsFilter *regexp.Regexp } func (c *oidcConnector) Close() error { @@ -446,6 +464,9 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I if found { for _, v := range vs { if s, ok := v.(string); ok { + if c.groupsFilter != nil && !c.groupsFilter.MatchString(s) { + continue + } groups = append(groups, s) } else { return identity, fmt.Errorf("malformed \"%v\" claim", groupsKey) diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 4bb84a40d6..428a45cc5f 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -62,6 +62,7 @@ func TestHandleCallback(t *testing.T) { expectPreferredUsername string expectedEmailField string token map[string]interface{} + groupsRegex string newGroupFromClaims []NewGroupFromClaims }{ { @@ -362,6 +363,23 @@ func TestHandleCallback(t *testing.T) { "non-string-claim2": 666, }, }, + { + name: "filterGroupClaims", + userIDKey: "", // not configured + userNameKey: "", // not configured + groupsRegex: `^.*\d$`, + expectUserID: "subvalue", + expectUserName: "namevalue", + expectGroups: []string{"group1", "group2"}, + expectedEmailField: "emailvalue", + token: map[string]interface{}{ + "sub": "subvalue", + "name": "namevalue", + "groups": []string{"group1", "group2", "groupA", "groupB"}, + "email": "emailvalue", + "email_verified": true, + }, + }, } for _, tc := range tests { @@ -398,6 +416,7 @@ func TestHandleCallback(t *testing.T) { config.ClaimMapping.EmailKey = tc.emailKey config.ClaimMapping.GroupsKey = tc.groupsKey config.ClaimMutations.NewGroupFromClaims = tc.newGroupFromClaims + config.ClaimMutations.FilterGroupClaims.GroupsFilter = tc.groupsRegex conn, err := newConnector(config) if err != nil {