diff --git a/api/types/integration.go b/api/types/integration.go index 45a514ab6e3cb..189ecd959438e 100644 --- a/api/types/integration.go +++ b/api/types/integration.go @@ -29,6 +29,9 @@ import ( const ( // IntegrationSubKindAWSOIDC is an integration with AWS that uses OpenID Connect as an Identity Provider. IntegrationSubKindAWSOIDC = "aws-oidc" + + // IntegrationSubKindAzureOIDC is an integration with Azure that uses OpenID Connect as an Identity Provider. + IntegrationSubKindAzureOIDC = "azure-oidc" ) // Integration specifies is a connection configuration between Teleport and a 3rd party system. @@ -47,6 +50,9 @@ type Integration interface { // SetAWSOIDCIssuerS3URI sets the IssuerS3URI of the AWS OIDC Spec. // Eg, s3://my-bucket/my-prefix SetAWSOIDCIssuerS3URI(string) + + // GetAzureOIDCIntegrationSpec returns the `azure-oidc` spec fields. + GetAzureOIDCIntegrationSpec() *AzureOIDCIntegrationSpecV1 } var _ ResourceWithLabels = (*IntegrationV1)(nil) @@ -72,6 +78,27 @@ func NewIntegrationAWSOIDC(md Metadata, spec *AWSOIDCIntegrationSpecV1) (*Integr return ig, nil } +// NewIntegrationAzureOIDC returns a new `azure-oidc` subkind Integration +func NewIntegrationAzureOIDC(md Metadata, spec *AzureOIDCIntegrationSpecV1) (*IntegrationV1, error) { + ig := &IntegrationV1{ + ResourceHeader: ResourceHeader{ + Metadata: md, + Kind: KindIntegration, + Version: V1, + SubKind: IntegrationSubKindAzureOIDC, + }, + Spec: IntegrationSpecV1{ + SubKindSpec: &IntegrationSpecV1_AzureOIDC{ + AzureOIDC: spec, + }, + }, + } + if err := ig.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return ig, nil +} + // String returns the integration string representation. func (ig *IntegrationV1) String() string { return fmt.Sprintf("IntegrationV1(Name=%v, SubKind=%s, Labels=%v)", @@ -128,6 +155,11 @@ func (s *IntegrationSpecV1) CheckAndSetDefaults() error { if err != nil { return trace.Wrap(err) } + case *IntegrationSpecV1_AzureOIDC: + err := integrationSubKind.Validate() + if err != nil { + return trace.Wrap(err) + } default: return trace.BadParameter("unknown integration subkind: %T", integrationSubKind) } @@ -135,7 +167,7 @@ func (s *IntegrationSpecV1) CheckAndSetDefaults() error { return nil } -// CheckAndSetDefaults validates an agent mesh integration. +// CheckAndSetDefaults validates the configuration for AWS OIDC integration subkind. func (s *IntegrationSpecV1_AWSOIDC) CheckAndSetDefaults() error { if s == nil || s.AWSOIDC == nil { return trace.BadParameter("aws_oidc is required for %q subkind", IntegrationSubKindAWSOIDC) @@ -160,6 +192,21 @@ func (s *IntegrationSpecV1_AWSOIDC) CheckAndSetDefaults() error { return nil } +// Validate validates the configuration for Azure OIDC integration subkind. +func (s *IntegrationSpecV1_AzureOIDC) Validate() error { + if s == nil || s.AzureOIDC == nil { + return trace.BadParameter("azure_oidc is required for %q subkind", IntegrationSubKindAzureOIDC) + } + if s.AzureOIDC.TenantID == "" { + return trace.BadParameter("tenant_id must be set") + } + if s.AzureOIDC.ClientID == "" { + return trace.BadParameter("client_id must be set") + } + + return nil +} + // GetAWSOIDCIntegrationSpec returns the specific spec fields for `aws-oidc` subkind integrations. func (ig *IntegrationV1) GetAWSOIDCIntegrationSpec() *AWSOIDCIntegrationSpecV1 { return ig.Spec.GetAWSOIDC() @@ -198,6 +245,11 @@ func (ig *IntegrationV1) SetAWSOIDCIssuerS3URI(issuerS3URI string) { } } +// GetAzureOIDCIntegrationSpec returns the specific spec fields for `azure-oidc` subkind integrations. +func (ig *IntegrationV1) GetAzureOIDCIntegrationSpec() *AzureOIDCIntegrationSpecV1 { + return ig.Spec.GetAzureOIDC() +} + // Integrations is a list of Integration resources. type Integrations []Integration @@ -247,7 +299,8 @@ func (ig *IntegrationV1) UnmarshalJSON(data []byte) error { d := struct { ResourceHeader `json:""` Spec struct { - AWSOIDC json.RawMessage `json:"aws_oidc"` + AWSOIDC json.RawMessage `json:"aws_oidc"` + AzureOIDC json.RawMessage `json:"azure_oidc"` } `json:"spec"` }{} @@ -270,6 +323,17 @@ func (ig *IntegrationV1) UnmarshalJSON(data []byte) error { integration.Spec.SubKindSpec = subkindSpec + case IntegrationSubKindAzureOIDC: + subkindSpec := &IntegrationSpecV1_AzureOIDC{ + AzureOIDC: &AzureOIDCIntegrationSpecV1{}, + } + + if err := json.Unmarshal(d.Spec.AzureOIDC, subkindSpec.AzureOIDC); err != nil { + return trace.Wrap(err) + } + + integration.Spec.SubKindSpec = subkindSpec + default: return trace.BadParameter("invalid subkind %q", integration.ResourceHeader.SubKind) } @@ -290,7 +354,8 @@ func (ig *IntegrationV1) MarshalJSON() ([]byte, error) { d := struct { ResourceHeader `json:""` Spec struct { - AWSOIDC AWSOIDCIntegrationSpecV1 `json:"aws_oidc"` + AWSOIDC AWSOIDCIntegrationSpecV1 `json:"aws_oidc,omitempty"` + AzureOIDC AzureOIDCIntegrationSpecV1 `json:"azure_oidc,omitempty"` } `json:"spec"` }{} @@ -303,6 +368,12 @@ func (ig *IntegrationV1) MarshalJSON() ([]byte, error) { } d.Spec.AWSOIDC = *ig.GetAWSOIDCIntegrationSpec() + case IntegrationSubKindAzureOIDC: + if ig.GetAzureOIDCIntegrationSpec() == nil { + return nil, trace.BadParameter("missing subkind data for %q subkind", ig.SubKind) + } + + d.Spec.AzureOIDC = *ig.GetAzureOIDCIntegrationSpec() default: return nil, trace.BadParameter("invalid subkind %q", ig.SubKind) } diff --git a/api/types/integration_test.go b/api/types/integration_test.go index 73828198f504b..b64cbc7688315 100644 --- a/api/types/integration_test.go +++ b/api/types/integration_test.go @@ -28,7 +28,7 @@ import ( ) func TestIntegrationJSONMarshalCycle(t *testing.T) { - ig, err := NewIntegrationAWSOIDC( + aws, err := NewIntegrationAWSOIDC( Metadata{Name: "some-integration"}, &AWSOIDCIntegrationSpecV1{ RoleARN: "arn:aws:iam::123456789012:role/DevTeams", @@ -37,14 +37,29 @@ func TestIntegrationJSONMarshalCycle(t *testing.T) { ) require.NoError(t, err) - bs, err := json.Marshal(ig) + azure, err := NewIntegrationAzureOIDC( + Metadata{Name: "some-integration"}, + &AzureOIDCIntegrationSpecV1{ + TenantID: "foo-bar", + ClientID: "baz-quux", + }, + ) require.NoError(t, err) - var ig2 IntegrationV1 - err = json.Unmarshal(bs, &ig2) - require.NoError(t, err) + allIntegrations := []*IntegrationV1{aws, azure} - require.Equal(t, &ig2, ig) + for _, ig := range allIntegrations { + t.Run(ig.SubKind, func(t *testing.T) { + bs, err := json.Marshal(ig) + require.NoError(t, err) + + var ig2 IntegrationV1 + err = json.Unmarshal(bs, &ig2) + require.NoError(t, err) + + require.Equal(t, &ig2, ig) + }) + } } func TestIntegrationCheckAndSetDefaults(t *testing.T) { @@ -59,7 +74,7 @@ func TestIntegrationCheckAndSetDefaults(t *testing.T) { expectedErrorIs func(error) bool }{ { - name: "valid", + name: "aws-oidc: valid", integration: func(name string) (*IntegrationV1, error) { return NewIntegrationAWSOIDC( Metadata{ @@ -104,9 +119,7 @@ func TestIntegrationCheckAndSetDefaults(t *testing.T) { nil, ) }, - expectedErrorIs: func(err error) bool { - return trace.IsBadParameter(err) - }, + expectedErrorIs: trace.IsBadParameter, }, { name: "aws-oidc: error when issuer is not a valid url", @@ -121,9 +134,7 @@ func TestIntegrationCheckAndSetDefaults(t *testing.T) { }, ) }, - expectedErrorIs: func(err error) bool { - return trace.IsBadParameter(err) - }, + expectedErrorIs: trace.IsBadParameter, }, { name: "aws-oidc: issuer is not an s3 url", @@ -138,9 +149,7 @@ func TestIntegrationCheckAndSetDefaults(t *testing.T) { }, ) }, - expectedErrorIs: func(err error) bool { - return trace.IsBadParameter(err) - }, + expectedErrorIs: trace.IsBadParameter, }, { name: "aws-oidc: error when no role is provided", @@ -152,9 +161,71 @@ func TestIntegrationCheckAndSetDefaults(t *testing.T) { &AWSOIDCIntegrationSpecV1{}, ) }, - expectedErrorIs: func(err error) bool { - return trace.IsBadParameter(err) + expectedErrorIs: trace.IsBadParameter, + }, + { + name: "azure-oidc: valid", + integration: func(name string) (*IntegrationV1, error) { + return NewIntegrationAzureOIDC( + Metadata{ + Name: name, + }, + &AzureOIDCIntegrationSpecV1{ + ClientID: "baz-quux", + TenantID: "foo-bar", + }, + ) + }, + expectedIntegration: func(name string) *IntegrationV1 { + return &IntegrationV1{ + ResourceHeader: ResourceHeader{ + Kind: KindIntegration, + SubKind: IntegrationSubKindAzureOIDC, + Version: V1, + Metadata: Metadata{ + Name: name, + Namespace: defaults.Namespace, + }, + }, + Spec: IntegrationSpecV1{ + SubKindSpec: &IntegrationSpecV1_AzureOIDC{ + AzureOIDC: &AzureOIDCIntegrationSpecV1{ + ClientID: "baz-quux", + TenantID: "foo-bar", + }, + }, + }, + } + }, + expectedErrorIs: noErrorFunc, + }, + { + name: "azure-oidc: error when no tenant id is provided", + integration: func(name string) (*IntegrationV1, error) { + return NewIntegrationAzureOIDC( + Metadata{ + Name: name, + }, + &AzureOIDCIntegrationSpecV1{ + ClientID: "baz-quux", + }, + ) + }, + expectedErrorIs: trace.IsBadParameter, + }, + { + name: "azure-oidc: error when no client id is provided", + integration: func(name string) (*IntegrationV1, error) { + return NewIntegrationAzureOIDC( + Metadata{ + Name: name, + }, + &AzureOIDCIntegrationSpecV1{ + TenantID: "foo-bar", + }, + ) }, + expectedErrorIs: trace.IsBadParameter, }, } { t.Run(tt.name, func(t *testing.T) { diff --git a/api/types/plugin.go b/api/types/plugin.go index 9fe95a47fee23..32f490cc005cb 100644 --- a/api/types/plugin.go +++ b/api/types/plugin.go @@ -39,6 +39,7 @@ var AllPluginTypes = []PluginType{ PluginTypePagerDuty, PluginTypeMattermost, PluginTypeDiscord, + PluginTypeEntraID, } const ( @@ -66,6 +67,8 @@ const ( PluginTypeDiscord = "discord" // PluginTypeGitlab indicates the Gitlab access plugin PluginTypeGitlab = "gitlab" + // PluginTypeEntraID indicates the Entra ID sync plugin + PluginTypeEntraID = "entra-id" ) // PluginSubkind represents the type of the plugin, e.g., access request, MDM etc. @@ -294,6 +297,13 @@ func (p *PluginV1) CheckAndSetDefaults() error { if staticCreds == nil { return trace.BadParameter("Gitlab plugin must be used with the static credentials ref type") } + case *PluginSpecV1_EntraId: + if settings.EntraId == nil { + return trace.BadParameter("missing Entra ID settings") + } + if err := settings.EntraId.Validate(); err != nil { + return trace.Wrap(err) + } default: return trace.BadParameter("settings are not set or have an unknown type") } @@ -459,6 +469,8 @@ func (p *PluginV1) GetType() PluginType { return PluginTypeServiceNow case *PluginSpecV1_Gitlab: return PluginTypeGitlab + case *PluginSpecV1_EntraId: + return PluginTypeEntraID default: return PluginTypeUnknown } @@ -609,6 +621,17 @@ func (c *PluginOAuth2AccessTokenCredentials) CheckAndSetDefaults() error { return nil } +func (c *PluginEntraIDSettings) Validate() error { + if c.SyncSettings == nil { + return trace.BadParameter("sync_settings must be set") + } + if len(c.SyncSettings.DefaultOwners) == 0 { + return trace.BadParameter("sync_settings.default_owners must be set") + } + + return nil +} + // GetCode returns the status code func (c PluginStatusV1) GetCode() PluginStatusCode { return c.Code diff --git a/api/types/plugin_test.go b/api/types/plugin_test.go index ad41bcca37471..cf35b3b41bd18 100644 --- a/api/types/plugin_test.go +++ b/api/types/plugin_test.go @@ -668,7 +668,7 @@ func requireBadParameterError(t require.TestingT, err error, args ...any) { require.True(t, trace.IsBadParameter(err), args...) } -func reqireNamedBadParameterError(name string) require.ErrorAssertionFunc { +func requireNamedBadParameterError(name string) require.ErrorAssertionFunc { return func(t require.TestingT, err error, args ...any) { if tt, ok := t.(*testing.T); ok { tt.Helper() @@ -718,15 +718,15 @@ func TestPluginJiraValidation(t *testing.T) { }, { name: "Missing Server URL", mutateSettings: func(s *PluginSpecV1_Jira) { s.Jira.ServerUrl = "" }, - assertErr: reqireNamedBadParameterError("server URL"), + assertErr: requireNamedBadParameterError("server URL"), }, { name: "Missing Project Key", mutateSettings: func(s *PluginSpecV1_Jira) { s.Jira.ProjectKey = "" }, - assertErr: reqireNamedBadParameterError("project key"), + assertErr: requireNamedBadParameterError("project key"), }, { name: "Missing Issue Type", mutateSettings: func(s *PluginSpecV1_Jira) { s.Jira.IssueType = "" }, - assertErr: reqireNamedBadParameterError("issue type"), + assertErr: requireNamedBadParameterError("issue type"), }, { name: "Missing Credentials", mutateCreds: func(c *PluginCredentialsV1) { c.Credentials = nil }, @@ -738,13 +738,13 @@ func TestPluginJiraValidation(t *testing.T) { StaticCredentialsRef. Labels = map[string]string{} }, - assertErr: reqireNamedBadParameterError("labels"), + assertErr: requireNamedBadParameterError("labels"), }, { name: "Invalid Credential Type", mutateCreds: func(c *PluginCredentialsV1) { c.Credentials = &PluginCredentialsV1_Oauth2AccessToken{} }, - assertErr: reqireNamedBadParameterError("static credentials"), + assertErr: requireNamedBadParameterError("static credentials"), }, } @@ -806,7 +806,7 @@ func TestPluginDiscordValidation(t *testing.T) { mutateSettings: func(s *PluginSpecV1_Discord) { s.Discord.RoleToRecipients = map[string]*DiscordChannels{} }, - assertErr: reqireNamedBadParameterError("role_to_recipients"), + assertErr: requireNamedBadParameterError("role_to_recipients"), }, { name: "Missing Default Mapping", mutateSettings: func(s *PluginSpecV1_Discord) { @@ -815,7 +815,7 @@ func TestPluginDiscordValidation(t *testing.T) { ChannelIds: []string{"1234567890"}, } }, - assertErr: reqireNamedBadParameterError("default entry"), + assertErr: requireNamedBadParameterError("default entry"), }, { name: "Missing Credentials", mutateCreds: func(c *PluginCredentialsV1) { c.Credentials = nil }, @@ -825,7 +825,7 @@ func TestPluginDiscordValidation(t *testing.T) { mutateCreds: func(c *PluginCredentialsV1) { c.Credentials = &PluginCredentialsV1_Oauth2AccessToken{} }, - assertErr: reqireNamedBadParameterError("static credentials"), + assertErr: requireNamedBadParameterError("static credentials"), }, } @@ -849,3 +849,62 @@ func TestPluginDiscordValidation(t *testing.T) { }) } } + +func TestPluginEntraIDValidation(t *testing.T) { + validSettings := func() *PluginSpecV1_EntraId { + return &PluginSpecV1_EntraId{ + EntraId: &PluginEntraIDSettings{ + SyncSettings: &PluginEntraIDSyncSettings{ + DefaultOwners: []string{"admin"}, + }, + }, + } + } + testCases := []struct { + name string + mutateSettings func(*PluginSpecV1_EntraId) + assertErr require.ErrorAssertionFunc + }{ + { + name: "valid", + mutateSettings: nil, + assertErr: require.NoError, + }, + { + name: "missing sync settings", + mutateSettings: func(s *PluginSpecV1_EntraId) { + s.EntraId.SyncSettings = nil + }, + assertErr: requireNamedBadParameterError("sync_settings"), + }, + { + name: "missing default owners", + mutateSettings: func(s *PluginSpecV1_EntraId) { + s.EntraId.SyncSettings.DefaultOwners = nil + }, + assertErr: requireNamedBadParameterError("sync_settings.default_owners"), + }, + { + name: "empty default owners", + mutateSettings: func(s *PluginSpecV1_EntraId) { + s.EntraId.SyncSettings.DefaultOwners = []string{} + }, + assertErr: requireNamedBadParameterError("sync_settings.default_owners"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + settings := validSettings() + if tc.mutateSettings != nil { + tc.mutateSettings(settings) + } + + plugin := NewPluginV1( + Metadata{Name: "uut"}, + PluginSpecV1{Settings: settings}, + nil) + tc.assertErr(t, plugin.CheckAndSetDefaults()) + }) + } +} diff --git a/e_imports.go b/e_imports.go index e3ace1fe4c5d3..012ce4063ced0 100644 --- a/e_imports.go +++ b/e_imports.go @@ -83,6 +83,8 @@ import ( _ "github.com/jonboulle/clockwork" _ "github.com/julienschmidt/httprouter" _ "github.com/mailgun/holster/v3/clock" + _ "github.com/microsoft/kiota-authentication-azure-go" + _ "github.com/microsoftgraph/msgraph-sdk-go" _ "github.com/mitchellh/mapstructure" _ "github.com/okta/okta-sdk-golang/v2/okta" _ "github.com/okta/okta-sdk-golang/v2/okta/query" diff --git a/go.mod b/go.mod index e2f63e67a9f0e..9cb194f870621 100644 --- a/go.mod +++ b/go.mod @@ -145,6 +145,8 @@ require ( github.com/mattn/go-sqlite3 v1.14.22 github.com/mdlayher/netlink v1.7.2 github.com/microsoft/go-mssqldb v0.0.0-00010101000000-000000000000 // replaced + github.com/microsoft/kiota-authentication-azure-go v1.0.2 + github.com/microsoftgraph/msgraph-sdk-go v1.37.0 github.com/miekg/pkcs11 v1.1.1 github.com/mitchellh/mapstructure v1.5.0 github.com/moby/term v0.5.0 @@ -285,6 +287,7 @@ require ( github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/chai2010/gettext-go v1.0.2 // indirect + github.com/cjlapao/common-go v0.0.39 // indirect github.com/cloudflare/cfssl v1.6.4 // indirect github.com/containerd/console v1.0.4 // indirect github.com/containerd/containerd v1.7.12 // indirect @@ -417,6 +420,13 @@ require ( github.com/mattn/go-localereader v0.0.1 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mdlayher/socket v0.5.1 // indirect + github.com/microsoft/kiota-abstractions-go v1.6.0 // indirect + github.com/microsoft/kiota-http-go v1.3.1 // indirect + github.com/microsoft/kiota-serialization-form-go v1.0.0 // indirect + github.com/microsoft/kiota-serialization-json-go v1.0.7 // indirect + github.com/microsoft/kiota-serialization-multipart-go v1.0.0 // indirect + github.com/microsoft/kiota-serialization-text-go v1.0.0 // indirect + github.com/microsoftgraph/msgraph-sdk-go-core v1.1.0 // indirect github.com/mitchellh/colorstring v0.0.0-20190213212951-d06e56a500db // indirect github.com/mitchellh/copystructure v1.2.0 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect @@ -480,6 +490,7 @@ require ( github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/spf13/viper v1.18.2 // indirect + github.com/std-uritemplate/std-uritemplate/go v0.0.55 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/syndtr/goleveldb v1.0.1-0.20220721030215-126854af5e6d // indirect diff --git a/go.sum b/go.sum index f96ae2a9ff3e7..982a86cfbf14b 100644 --- a/go.sum +++ b/go.sum @@ -980,6 +980,8 @@ github.com/chrismellard/docker-credential-acr-env v0.0.0-20230304212654-82a0ddb2 github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= +github.com/cjlapao/common-go v0.0.39 h1:bAAUrj2B9v0kMzbAOhzjSmiyDy+rd56r2sy7oEiQLlA= +github.com/cjlapao/common-go v0.0.39/go.mod h1:M3dzazLjTjEtZJbbxoA5ZDiGCiHmpwqW9l4UWaddwOA= github.com/clbanning/mxj/v2 v2.7.0 h1:WA/La7UGCanFe5NpHF0Q3DNtnCsVoxbPKuyBNHWRyME= github.com/clbanning/mxj/v2 v2.7.0/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= @@ -1852,6 +1854,24 @@ github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/ github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/microsoft/kiota-abstractions-go v1.6.0 h1:qbGBNMU0/o5myKbikCBXJFohVCFrrpx2cO15Rta2WyA= +github.com/microsoft/kiota-abstractions-go v1.6.0/go.mod h1:7YH20ZbRWXGfHSSvdHkdztzgCB9mRdtFx13+hrYIEpo= +github.com/microsoft/kiota-authentication-azure-go v1.0.2 h1:tClGeyFZJ+4Bakf8u0euPM4wqy4ethycdOgx3jyH3pI= +github.com/microsoft/kiota-authentication-azure-go v1.0.2/go.mod h1:aTcti0bUJEcq7kBfQG4Sr4ElvRNuaalXcFEu4iEyQ6M= +github.com/microsoft/kiota-http-go v1.3.1 h1:S+ZDxE7Pc/Z06hbfqpFHkoq5xiC8/7d12iNovcgl+7o= +github.com/microsoft/kiota-http-go v1.3.1/go.mod h1:4QjB+as08swnZXZLx5I+ZHZ8U/tVy7Zu49RNTmWgw48= +github.com/microsoft/kiota-serialization-form-go v1.0.0 h1:UNdrkMnLFqUCccQZerKjblsyVgifS11b3WCx+eFEsAI= +github.com/microsoft/kiota-serialization-form-go v1.0.0/go.mod h1:h4mQOO6KVTNciMF6azi1J9QB19ujSw3ULKcSNyXXOMA= +github.com/microsoft/kiota-serialization-json-go v1.0.7 h1:yMbckSTPrjZdM4EMXgzLZSA3CtDaUBI350u0VoYRz7Y= +github.com/microsoft/kiota-serialization-json-go v1.0.7/go.mod h1:1krrY7DYl3ivPIzl4xTaBpew6akYNa8/Tal8g+kb0cc= +github.com/microsoft/kiota-serialization-multipart-go v1.0.0 h1:3O5sb5Zj+moLBiJympbXNaeV07K0d46IfuEd5v9+pBs= +github.com/microsoft/kiota-serialization-multipart-go v1.0.0/go.mod h1:yauLeBTpANk4L03XD985akNysG24SnRJGaveZf+p4so= +github.com/microsoft/kiota-serialization-text-go v1.0.0 h1:XOaRhAXy+g8ZVpcq7x7a0jlETWnWrEum0RhmbYrTFnA= +github.com/microsoft/kiota-serialization-text-go v1.0.0/go.mod h1:sM1/C6ecnQ7IquQOGUrUldaO5wj+9+v7G2W3sQ3fy6M= +github.com/microsoftgraph/msgraph-sdk-go v1.37.0 h1:wD62FzIBu4gVg70ikAm7D45tFqvDKo+K6aJ+zeNWlAE= +github.com/microsoftgraph/msgraph-sdk-go v1.37.0/go.mod h1:xYBUc+4LGRjYRyTF4CLiKAqb981AhlSfvRX11is25q4= +github.com/microsoftgraph/msgraph-sdk-go-core v1.1.0 h1:NB7c/n4Knj+TLaLfjsahhSqoUqoN/CtyNB0XIe/nJnM= +github.com/microsoftgraph/msgraph-sdk-go-core v1.1.0/go.mod h1:M3w/5IFJ1u/DpwOyjsjNSVEA43y1rLOeX58suyfBhGk= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKjuso= github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo= @@ -2190,6 +2210,8 @@ github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= github.com/spiffe/go-spiffe/v2 v2.2.0 h1:9Vf06UsvsDbLYK/zJ4sYsIsHmMFknUD+feA7IYoWMQY= github.com/spiffe/go-spiffe/v2 v2.2.0/go.mod h1:Urzb779b3+IwDJD2ZbN8fVl3Aa8G4N/PiUe6iXC0XxU= +github.com/std-uritemplate/std-uritemplate/go v0.0.55 h1:muSH037g97K7U2f94G9LUuE8tZlJsoSSrPsO9V281WY= +github.com/std-uritemplate/std-uritemplate/go v0.0.55/go.mod h1:rG/bqh/ThY4xE5de7Rap3vaDkYUT76B0GPJ0loYeTTc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= diff --git a/integration/appaccess/jwt.go b/integration/appaccess/jwt.go index 4359dfbd80df9..b9715a3369a24 100644 --- a/integration/appaccess/jwt.go +++ b/integration/appaccess/jwt.go @@ -39,7 +39,7 @@ func verifyJWT(t *testing.T, pack *Pack, token, appURI string) { var jwks web.JWKSResponse err = json.Unmarshal([]byte(body), &jwks) require.NoError(t, err) - require.Len(t, jwks.Keys, 1) + require.Len(t, jwks.Keys, 2) // For backwards compatibility the same key is included twice in JWKS with different key IDs. publicKey, err := jwt.UnmarshalJWK(jwks.Keys[0]) require.NoError(t, err) diff --git a/lib/integrations/azureoidc/token_generator.go b/lib/integrations/azureoidc/token_generator.go new file mode 100644 index 0000000000000..e31cf2e14f5da --- /dev/null +++ b/lib/integrations/azureoidc/token_generator.go @@ -0,0 +1,96 @@ +// Teleport +// Copyright (C) 2024 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package azureoidc + +import ( + "context" + "crypto" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/jwt" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils/oidc" +) + +// azureDefaultJWTAudience is the default audience used by Azure +// when setting up an enterprise application. +const azureDefaultJWTAudience = "api://AzureADTokenExchange" + +// KeyStoreManager defines methods to get signers using the server's keystore. +type KeyStoreManager interface { + // GetJWTSigner selects a usable JWT keypair from the given keySet and returns a [crypto.Signer]. + GetJWTSigner(ctx context.Context, ca types.CertAuthority) (crypto.Signer, error) +} + +// Cache is the subset of the cached resources that the AWS OIDC Token Generation queries. +type Cache interface { + // GetCertAuthority returns cert authority by id + GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) + + // GetClusterName returns local cluster name of the current auth server + GetClusterName(...services.MarshalOption) (types.ClusterName, error) + + // GetProxies returns a list of registered proxies. + GetProxies() ([]types.Server, error) +} + +// GenerateEntraOIDCToken returns a JWT suitable for OIDC authentication to MS Graph API. +func GenerateEntraOIDCToken(ctx context.Context, cache Cache, manager KeyStoreManager, clock clockwork.Clock) (string, error) { + clusterName, err := cache.GetClusterName() + if err != nil { + return "", trace.Wrap(err) + } + + ca, err := cache.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.OIDCIdPCA, + DomainName: clusterName.GetClusterName(), + }, true /*loadKeys*/) + if err != nil { + return "", trace.Wrap(err) + } + + issuer, err := oidc.IssuerForCluster(ctx, cache) + if err != nil { + return "", trace.Wrap(err) + } + + signer, err := manager.GetJWTSigner(ctx, ca) + if err != nil { + return "", trace.Wrap(err) + } + + privateKey, err := services.GetJWTSigner(signer, ca.GetClusterName(), clock) + if err != nil { + return "", trace.Wrap(err) + } + + token, err := privateKey.SignEntraOIDC(jwt.SignParams{ + Audience: azureDefaultJWTAudience, + Subject: "teleport-azure", // TODO(justinas): consider moving this to a constant or a field in the integration settings + Issuer: issuer, + Expires: clock.Now().Add(time.Hour), + }) + if err != nil { + return "", trace.Wrap(err) + } + + return token, nil +} diff --git a/lib/jwt/jwk.go b/lib/jwt/jwk.go index c851ba5725d11..68d575b86eb66 100644 --- a/lib/jwt/jwk.go +++ b/lib/jwt/jwk.go @@ -21,6 +21,8 @@ package jwt import ( "crypto" "crypto/rsa" + "crypto/sha256" + "crypto/x509" "encoding/base64" "math/big" @@ -50,6 +52,12 @@ type JWK struct { KeyID string `json:"kid"` } +// KeyID returns a key id derived from the public key. +func KeyID(pub *rsa.PublicKey) string { + hash := sha256.Sum256(x509.MarshalPKCS1PublicKey(pub)) + return base64.RawURLEncoding.EncodeToString(hash[:]) +} + // MarshalJWK will marshal a supported public key into JWK format. func MarshalJWK(bytes []byte) (JWK, error) { // Parse the public key and validate type. @@ -69,7 +77,7 @@ func MarshalJWK(bytes []byte) (JWK, error) { N: base64.RawURLEncoding.EncodeToString(publicKey.N.Bytes()), E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(publicKey.E)).Bytes()), Use: defaults.JWTUse, - KeyID: "", + KeyID: KeyID(publicKey), }, nil } diff --git a/lib/jwt/jwk_test.go b/lib/jwt/jwk_test.go index 6cfc91d2102d5..d71881d69d637 100644 --- a/lib/jwt/jwk_test.go +++ b/lib/jwt/jwk_test.go @@ -19,6 +19,11 @@ package jwt import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "math/big" "testing" "github.com/stretchr/testify/require" @@ -34,3 +39,47 @@ func TestMarshalJWK(t *testing.T) { // Required for integrating with AWS OpenID Connect Identity Provider. require.Equal(t, "sig", jwk.Use) } + +func TestKeyIDHasConsistentOutputForAnInput(t *testing.T) { + t.Parallel() + + privateKey, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + publicKey := privateKey.Public().(*rsa.PublicKey) + id1 := KeyID(publicKey) + id2 := KeyID(publicKey) + require.NotEmpty(t, id1) + require.Equal(t, id1, id2) + + expectedLength := base64.RawURLEncoding.EncodedLen(sha256.Size) + require.Len(t, id1, expectedLength, "expected key id to always be %d characters long", expectedLength) +} + +func TestKeyIDHasDistinctOutputForDifferingInputs(t *testing.T) { + t.Parallel() + + privateKey1, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + privateKey2, err := rsa.GenerateKey(rand.Reader, 1024) + require.NoError(t, err) + publicKey1 := privateKey1.Public().(*rsa.PublicKey) + publicKey2 := privateKey2.Public().(*rsa.PublicKey) + id1 := KeyID(publicKey1) + id2 := KeyID(publicKey2) + require.NotEmpty(t, id1) + require.NotEmpty(t, id2) + require.NotEqual(t, id1, id2) +} + +// TestKeyIDCompatibility ensures we do not introduce a change in the KeyID algorithm for existing keys. +// It does so by ensuring that a pre-generated public key results in the expected value. +func TestKeyIDCompatibility(t *testing.T) { + n, ok := new(big.Int). + SetString("10804584566601725083798733714540307814537881454603593919227265169397611763416631197061041949793088023127406259586903197568870611092333639226643589004457719", 10) + require.True(t, ok, "failed to create a bigint") + publicKey := &rsa.PublicKey{ + E: 65537, + N: n, + } + require.Equal(t, "GDLHLDvPUYmNLVU3WgshDX7bAw8xEmML8ypeE9KRAEQ", KeyID(publicKey)) +} diff --git a/lib/jwt/jwt.go b/lib/jwt/jwt.go index 4e24a8541c0a3..083d3b4139107 100644 --- a/lib/jwt/jwt.go +++ b/lib/jwt/jwt.go @@ -145,12 +145,13 @@ func (p *SignParams) Check() error { } // sign will return a signed JWT with the passed in claims embedded within. -func (k *Key) sign(claims any) (string, error) { - return k.signAny(claims) +// `opts`, when not nil, specifies additional signing options, such as additional JWT headers. +func (k *Key) sign(claims any, opts *jose.SignerOptions) (string, error) { + return k.signAny(claims, opts) } // signAny will return a signed JWT with the passed in claims embedded within; unlike sign it allows more flexibility in the claim data. -func (k *Key) signAny(claims any) (string, error) { +func (k *Key) signAny(claims any, opts *jose.SignerOptions) (string, error) { if k.config.PrivateKey == nil { return "", trace.BadParameter("can not sign token with non-signing key") } @@ -167,7 +168,12 @@ func (k *Key) signAny(claims any) (string, error) { Algorithm: k.config.Algorithm, Key: signer, } - sig, err := jose.NewSigner(signingKey, (&jose.SignerOptions{}).WithType("JWT")) + + if opts == nil { + opts = &jose.SignerOptions{} + } + opts = opts.WithType("JWT") + sig, err := jose.NewSigner(signingKey, opts) if err != nil { return "", trace.Wrap(err) } @@ -199,7 +205,7 @@ func (k *Key) Sign(p SignParams) (string, error) { Traits: p.Traits, } - return k.sign(claims) + return k.sign(claims, nil) } // awsOIDCCustomClaims defines the require claims for the JWT token used in AWS OIDC Integration. @@ -216,7 +222,7 @@ type awsOIDCCustomClaims struct { // - Issuer: stored as Issuer (iss) claim // - Subject: stored as Subject (sub) claim // - Audience: stored as Audience (aud) claim -// - Expiries: stored as Expiry (exp) claim +// - Expires: stored as Expiry (exp) claim func (k *Key) SignAWSOIDC(p SignParams) (string, error) { // Sign the claims and create a JWT token. claims := awsOIDCCustomClaims{ @@ -232,7 +238,42 @@ func (k *Key) SignAWSOIDC(p SignParams) (string, error) { }, } - return k.sign(claims) + // AWS does not require `kid` claim in the JWT per se, + // but it seems to (NB: educated guess) require it if JWKS has multiple JWK-s with different `kid`-s. + opts := (&jose.SignerOptions{}). + WithHeader(jose.HeaderKey("kid"), "") + + return k.sign(claims, opts) +} + +// SignEntraOIDC signs a JWT for the Entra ID Integration. +// Required Params: +// - Issuer: stored as Issuer (iss) claim +// - Subject: stored as Subject (sub) claim +// - Audience: stored as Audience (aud) claim +// - Expires: stored as Expiry (exp) claim +func (k *Key) SignEntraOIDC(p SignParams) (string, error) { + // Sign the claims and create a JWT token. + claims := jwt.Claims{ + Issuer: p.Issuer, + Subject: p.Subject, + Audience: jwt.Audience{p.Audience}, + ID: uuid.NewString(), + NotBefore: jwt.NewNumericDate(k.config.Clock.Now().Add(-10 * time.Second)), + Expiry: jwt.NewNumericDate(p.Expires), + IssuedAt: jwt.NewNumericDate(k.config.Clock.Now().Add(-10 * time.Second)), + } + + // Azure expect a `kid` header to be present and non-empty, + // unlike e.g. AWS which accepts an empty `kid` string value. + publicKey, ok := k.config.PublicKey.(*rsa.PublicKey) + if !ok { + return "", trace.BadParameter("expected an RSA public key") + } + kid := KeyID(publicKey) + opts := (&jose.SignerOptions{}). + WithHeader(jose.HeaderKey("kid"), kid) + return k.sign(claims, opts) } func (k *Key) SignSnowflake(p SignParams, issuer string) (string, error) { @@ -247,7 +288,7 @@ func (k *Key) SignSnowflake(p SignParams, issuer string) (string, error) { }, } - return k.sign(claims) + return k.sign(claims, nil) } // AzureTokenClaims represent a minimal set of claims that will be encoded as JWT in Azure access token and passed back to az CLI. @@ -260,7 +301,7 @@ type AzureTokenClaims struct { // SignAzureToken signs AzureTokenClaims func (k *Key) SignAzureToken(claims AzureTokenClaims) (string, error) { - return k.signAny(claims) + return k.signAny(claims, nil) } type PROXYSignParams struct { @@ -284,7 +325,7 @@ func (k *Key) SignPROXYJWT(p PROXYSignParams) (string, error) { }, } - return k.sign(claims) + return k.sign(claims, nil) } // VerifyParams are the parameters needed to pass the token and data needed to verify. diff --git a/lib/web/oidcidp.go b/lib/web/oidcidp.go index d3b34639908e3..028b61d0e349d 100644 --- a/lib/web/oidcidp.go +++ b/lib/web/oidcidp.go @@ -79,6 +79,11 @@ func (h *Handler) jwks(ctx context.Context, caType types.CertAuthType) (*JWKSRes return nil, trace.Wrap(err) } resp.Keys = append(resp.Keys, jwk) + + // Return an additional copy of the same JWK + // with KeyID set to the empty string for compatibility. + jwk.KeyID = "" + resp.Keys = append(resp.Keys, jwk) } return &resp, nil } diff --git a/lib/web/oidcidp_test.go b/lib/web/oidcidp_test.go index c399ed63488ae..ac346de37a8e8 100644 --- a/lib/web/oidcidp_test.go +++ b/lib/web/oidcidp_test.go @@ -71,11 +71,19 @@ func TestOIDCIdPPublicEndpoints(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, jwksKeys.Keys) - key := jwksKeys.Keys[0] - require.Equal(t, "sig", key.Use) - require.Equal(t, "RSA", key.KeyType) - require.Equal(t, "RS256", key.Alg) - require.NotNil(t, key.KeyID) // AWS requires this to be present (even if empty string). + require.Len(t, jwksKeys.Keys, 2) + + // Expect the same key twice, once with a synthesized Key ID, and once with an empty Key ID for compatibility. + key1 := jwksKeys.Keys[0] + key2 := jwksKeys.Keys[1] + require.Equal(t, "sig", key1.Use) + require.Equal(t, "RSA", key1.KeyType) + require.Equal(t, "RS256", key1.Alg) + require.Equal(t, key1.Use, key2.Use) + require.Equal(t, key1.KeyType, key2.KeyType) + require.Equal(t, key1.Alg, key2.Alg) + require.NotEmpty(t, *key1.KeyID) + require.Equal(t, "", *key2.KeyID) } func TestThumbprint(t *testing.T) {