diff --git a/CHANGELOG.md b/CHANGELOG.md index 787f2b2..b0f7847 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ Versioning](https://semver.org/spec/v2.0.0.html). ## [3.1.1] - 2023-02-02 +### Added +- Added Okta provider support with private key JWT authentication + ### Fixed * The `sts/:name` endpoint should have been seal-wrapped like the corresponding diff --git a/README.md b/README.md index a6e1f04..399747d 100644 --- a/README.md +++ b/README.md @@ -192,6 +192,41 @@ token_url_params map[audience:https://dev-example.us.auth0.com/api/v2/] type Bearer ``` + +#### Client credentials with Private Key JWT Authentication + +Store private key in a file (e.g., private_key.pem), then configure Okta server with private key +``` +export OKTA_PRIVATE_KEY=$(cat private_key.pem) + +$ vault write oauth2/servers/okta-example-jwt \ + provider=okta \ + provider_options=domain=dev-123456.okta.com \ + provider_options=private_key="$OKTA_PRIVATE_KEY" \ + client_id=0oa1234567890abcdef +Success! Data written to: oauth2/servers/okta-example-jwt +``` + +Configure credentials +``` +$ vault write oauth2/self/my-okta-jwt-auth \ + server=okta-example-jwt \ + grant_type=client_credentials \ + scopes=okta.groups.read +Success! Data written to: oauth2/self/my-okta-jwt-auth +``` + +``` +$ vault read oauth2/self/my-okta-jwt-auth +Key Value +--- ----- +access_token eyJraWQiOixxxx +expire_time 2024-11-05T21:41:13.392595-08:00 +scopes [okta.groups.read] +server okta-example-jwt +type Bearer +``` + ## Tips For some operations, you may find that you need to provide a map of data for a diff --git a/go.mod b/go.mod index 6b232f3..ecf2215 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,9 @@ go 1.18 require ( github.com/coreos/go-oidc v2.2.1+incompatible + github.com/go-jose/go-jose/v3 v3.0.3 github.com/golangci/golangci-lint v1.50.1 + github.com/google/uuid v1.3.0 github.com/hashicorp/go-hclog v1.4.0 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-uuid v1.0.3 @@ -149,7 +151,6 @@ require ( github.com/google/go-querystring v1.1.0 // indirect github.com/google/gofuzz v1.1.0 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/google/uuid v1.3.0 // indirect github.com/googleapis/enterprise-certificate-proxy v0.2.3 // indirect github.com/googleapis/gax-go/v2 v2.7.1 // indirect github.com/googleapis/gnostic v0.5.5 // indirect @@ -349,15 +350,15 @@ require ( go.uber.org/atomic v1.9.0 // indirect go.uber.org/multierr v1.7.0 // indirect go.uber.org/zap v1.19.1 // indirect - golang.org/x/crypto v0.6.0 // indirect + golang.org/x/crypto v0.19.0 // indirect golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e // indirect golang.org/x/exp/typeparams v0.0.0-20220827204233-334a2380cb91 // indirect golang.org/x/mod v0.8.0 // indirect - golang.org/x/net v0.8.0 // indirect + golang.org/x/net v0.10.0 // indirect golang.org/x/sync v0.1.0 // indirect - golang.org/x/sys v0.6.0 // indirect - golang.org/x/term v0.6.0 // indirect - golang.org/x/text v0.8.0 // indirect + golang.org/x/sys v0.17.0 // indirect + golang.org/x/term v0.17.0 // indirect + golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.0.0-20220411224347-583f2d630306 // indirect golang.org/x/tools v0.6.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect diff --git a/go.sum b/go.sum index 4c1b133..0b412ea 100644 --- a/go.sum +++ b/go.sum @@ -389,6 +389,8 @@ github.com/go-errors/errors v1.4.1 h1:IvVlgbzSsaUNudsw5dcXSzF3EWyXTi5XrAdngnuhRy github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= +github.com/go-jose/go-jose/v3 v3.0.3 h1:fFKWeig/irsp7XD2zBxvnmA/XaRWp5V3CBsZXJF7G7k= +github.com/go-jose/go-jose/v3 v3.0.3/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.9.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as= github.com/go-kit/kit v0.10.0/go.mod h1:xUsJbQ/Fp4kEt7AFgCuvyX4a71u8h9jB8tj/ORgOZ7o= @@ -1639,8 +1641,9 @@ golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211108221036-ceb1ce70b4fa/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.6.0 h1:qfktjS5LUO+fFKeJXZ+ikTRijMmljikvG68fpMMruSc= golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1753,8 +1756,8 @@ golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1879,16 +1882,18 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20220526004731-065cf7ba2467/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= -golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/text v0.0.0-20160726164857-2910a502d2bf/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1900,8 +1905,9 @@ golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= -golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/pkg/provider/okta.go b/pkg/provider/okta.go new file mode 100644 index 0000000..40a146d --- /dev/null +++ b/pkg/provider/okta.go @@ -0,0 +1,326 @@ +package provider + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + "github.com/go-jose/go-jose/v3" + "github.com/go-jose/go-jose/v3/jwt" + "github.com/google/uuid" + "github.com/puppetlabs/leg/errmap/pkg/errmark" + "golang.org/x/oauth2" + "golang.org/x/oauth2/clientcredentials" + + "github.com/puppetlabs/vault-plugin-secrets-oauthapp/v3/pkg/oauth2ext/devicecode" +) + +const ( + OptionDomain = "domain" + OptionPrivateKey = "private_key" + OptionScheme = "scheme" + OktaProviderV1 = 1 + + JWTExpirationTime = time.Hour +) + +func init() { + GlobalRegistry.MustRegister("okta", OktaFactory) +} + +type oktaToken struct { + TokenType string `json:"token_type,omitempty"` + ExpiresIn int `json:"expires_in,omitempty"` + AccessToken string `json:"access_token,omitempty"` + Scope string `json:"scope,omitempty"` +} + +type oktaOperations struct { + vsn int + endpointFactory EndpointFactoryFunc + clientID string + clientSecret string + privateKey *rsa.PrivateKey + usePrivateKey bool + httpClient *http.Client +} + +func (o *oktaOperations) createClientAssertion() (string, error) { + if !o.usePrivateKey || o.privateKey == nil { + return "", fmt.Errorf("private key not configured") + } + + signer, err := jose.NewSigner( + jose.SigningKey{Algorithm: jose.RS256, Key: o.privateKey}, + (&jose.SignerOptions{}).WithType("JWT"), + ) + if err != nil { + return "", fmt.Errorf("failed to create signer: %w", err) + } + + now := time.Now() + claims := jwt.Claims{ + Issuer: o.clientID, + Subject: o.clientID, + Audience: jwt.Audience{o.endpointFactory(nil).TokenURL}, + IssuedAt: jwt.NewNumericDate(now), + NotBefore: jwt.NewNumericDate(now), + Expiry: jwt.NewNumericDate(now.Add(JWTExpirationTime)), + ID: uuid.New().String(), + } + + token, err := jwt.Signed(signer).Claims(claims).CompactSerialize() + if err != nil { + return "", fmt.Errorf("failed to create JWT: %w", err) + } + + return token, nil +} + +func (o *oktaOperations) ClientCredentials(ctx context.Context, opts ...ClientCredentialsOption) (*Token, error) { + options := &ClientCredentialsOptions{} + options.ApplyOptions(opts) + + endpoint := o.endpointFactory(options.ProviderOptions) + + if o.usePrivateKey { + // Private Key JWT - Direct HTTP request + clientAssertion, err := o.createClientAssertion() + if err != nil { + return nil, fmt.Errorf("failed to create client assertion: %w", err) + } + + data := url.Values{} + data.Set("grant_type", "client_credentials") + data.Set("scope", strings.Join(options.Scopes, " ")) + data.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") + data.Set("client_assertion", clientAssertion) + + req, err := http.NewRequestWithContext(ctx, "POST", endpoint.TokenURL, strings.NewReader(data.Encode())) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Add("Accept", "application/json") + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var errResp struct { + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + } + if err := json.Unmarshal(body, &errResp); err == nil { + return nil, fmt.Errorf("authentication failed: %s, %s", errResp.Error, errResp.ErrorDescription) + } + return nil, fmt.Errorf("request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var oktaToken oktaToken + if err := json.Unmarshal(body, &oktaToken); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + token := &oauth2.Token{ + AccessToken: oktaToken.AccessToken, + TokenType: oktaToken.TokenType, + Expiry: time.Now().Add(time.Duration(oktaToken.ExpiresIn) * time.Second), + } + + return &Token{ + Token: token, + ProviderVersion: o.vsn, + ProviderOptions: options.ProviderOptions, + }, nil + } + + // Client Secret - Use clientcredentials package + if o.clientSecret == "" { + return nil, errmark.MarkUser(ErrMissingClientSecret) + } + + cc := &clientcredentials.Config{ + ClientID: o.clientID, + ClientSecret: o.clientSecret, + TokenURL: endpoint.TokenURL, + Scopes: options.Scopes, + EndpointParams: options.EndpointParams, + AuthStyle: endpoint.AuthStyle, + } + + tok, err := cc.Token(ctx) + if err != nil { + return nil, err + } + + return &Token{ + Token: tok, + ProviderVersion: o.vsn, + ProviderOptions: options.ProviderOptions, + }, nil +} + +func (o *oktaOperations) AuthCodeURL(state string, opts ...AuthCodeURLOption) (string, bool) { + options := &AuthCodeURLOptions{} + options.ApplyOptions(opts) + + endpoint := o.endpointFactory(options.ProviderOptions) + if endpoint.AuthURL == "" { + return "", false + } + + cfg := &oauth2.Config{ + Endpoint: endpoint.Endpoint, + ClientID: o.clientID, + Scopes: options.Scopes, + RedirectURL: options.RedirectURL, + } + + return cfg.AuthCodeURL(state, options.AuthCodeOptions...), true +} + +func (o *oktaOperations) DeviceCodeAuth(ctx context.Context, opts ...DeviceCodeAuthOption) (*devicecode.Auth, bool, error) { + return nil, false, nil +} + +func (o *oktaOperations) DeviceCodeExchange(ctx context.Context, deviceCode string, opts ...DeviceCodeExchangeOption) (*Token, error) { + return nil, fmt.Errorf("device code flow not supported") +} + +func (o *oktaOperations) AuthCodeExchange(ctx context.Context, code string, opts ...AuthCodeExchangeOption) (*Token, error) { + return nil, fmt.Errorf("auth code exchange flow not supported") +} + +func (o *oktaOperations) TokenExchange(ctx context.Context, t *Token, opts ...TokenExchangeOption) (*Token, error) { + return nil, fmt.Errorf("token exchange flow not supported") +} + +func (o *oktaOperations) RefreshToken(ctx context.Context, t *Token, opts ...RefreshTokenOption) (*Token, error) { + options := &RefreshTokenOptions{} + WithProviderOptions(t.ProviderOptions).ApplyToRefreshTokenOptions(options) + options.ApplyOptions(opts) + + endpoint := o.endpointFactory(options.ProviderOptions) + + cfg := &oauth2.Config{ + Endpoint: endpoint.Endpoint, + ClientID: o.clientID, + ClientSecret: o.clientSecret, + } + + tok, err := cfg.TokenSource(ctx, &oauth2.Token{ + RefreshToken: t.RefreshToken, + }).Token() + if err != nil { + return nil, err + } + + return &Token{ + Token: tok, + ProviderVersion: o.vsn, + ProviderOptions: options.ProviderOptions, + }, nil +} + +type okta struct { + vsn int + endpointFactory EndpointFactoryFunc + privateKey *rsa.PrivateKey + usePrivateKey bool +} + +func (o *okta) Version() int { + return o.vsn +} + +func (o *okta) Public(clientID string) PublicOperations { + return o.Private(clientID, "") +} + +func (o *okta) Private(clientID, clientSecret string) PrivateOperations { + return &oktaOperations{ + vsn: o.vsn, + endpointFactory: o.endpointFactory, + clientID: clientID, + clientSecret: clientSecret, + privateKey: o.privateKey, + usePrivateKey: o.usePrivateKey, + httpClient: &http.Client{}, + } +} + +func OktaFactory(ctx context.Context, vsn int, opts map[string]string) (Provider, error) { + vsn = selectVersion(vsn, OktaProviderV1) + + switch vsn { + case OktaProviderV1: + default: + return nil, ErrNoProviderWithVersion + } + + domain := opts[OptionDomain] + if domain == "" { + return nil, &OptionError{Option: OptionDomain, Cause: fmt.Errorf("domain is required")} + } + + var privateKey *rsa.PrivateKey + usePrivateKey := false + + if keyPEM := opts[OptionPrivateKey]; keyPEM != "" { + block, _ := pem.Decode([]byte(keyPEM)) + if block == nil { + return nil, &OptionError{Option: OptionPrivateKey, Cause: fmt.Errorf("failed to parse PEM block")} + } + + // Only support PKCS8 + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, &OptionError{Option: OptionPrivateKey, Cause: fmt.Errorf("failed to parse PKCS8 private key: %w", err)} + } + + var ok bool + privateKey, ok = key.(*rsa.PrivateKey) + if !ok { + return nil, &OptionError{Option: OptionPrivateKey, Cause: fmt.Errorf("key is not an RSA private key")} + } + + usePrivateKey = true + } + + scheme := opts[OptionScheme] + if scheme == "" { + scheme = "https" + } + + p := &okta{ + vsn: vsn, + endpointFactory: StaticEndpointFactory(Endpoint{ + Endpoint: oauth2.Endpoint{ + AuthURL: fmt.Sprintf("%s://%s/oauth2/v1/authorize", scheme, domain), + TokenURL: fmt.Sprintf("%s://%s/oauth2/v1/token", scheme, domain), + }, + }), + privateKey: privateKey, + usePrivateKey: usePrivateKey, + } + + return p, nil +} diff --git a/pkg/provider/okta_test.go b/pkg/provider/okta_test.go new file mode 100644 index 0000000..46d7af8 --- /dev/null +++ b/pkg/provider/okta_test.go @@ -0,0 +1,151 @@ +package provider + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func generateTestPKCS8Key(t *testing.T) (string, *rsa.PrivateKey) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + privateKeyBytes, err := x509.MarshalPKCS8PrivateKey(privateKey) + require.NoError(t, err) + + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: privateKeyBytes, + }) + + return string(privateKeyPEM), privateKey +} +func TestOktaClientCredentials(t *testing.T) { + const ( + testClientID = "test-client" + ) + + tests := []struct { + name string + usePrivateKey bool + clientID string + clientSecret string + wantErr bool + checkRequest func(t *testing.T, r *http.Request) + }{ + { + name: "private key success", + usePrivateKey: true, + clientID: testClientID, + wantErr: false, + checkRequest: func(t *testing.T, r *http.Request) { + err := r.ParseForm() + require.NoError(t, err) + + // Verify common parameters + assert.Equal(t, "POST", r.Method) + assert.Equal(t, "/oauth2/v1/token", r.URL.Path) + assert.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + assert.Equal(t, "client_credentials", r.Form.Get("grant_type")) + + // Verify JWT parameters + assert.Equal(t, "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", r.Form.Get("client_assertion_type")) + assert.NotEmpty(t, r.Form.Get("client_assertion")) + }, + }, + { + name: "missing client secret", + usePrivateKey: false, + clientID: testClientID, + wantErr: true, + checkRequest: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if tt.checkRequest != nil { + tt.checkRequest(t, r) + } + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "access_token": "test-token", + "token_type": "Bearer", + "expires_in": 3600 + }`)) + })) + defer server.Close() + + opts := map[string]string{ + "domain": strings.TrimPrefix(server.URL, "http://"), + "scheme": "http", + } + + if tt.usePrivateKey { + keyPEM, _ := generateTestPKCS8Key(t) + opts["private_key"] = keyPEM + } + + provider, err := OktaFactory(context.Background(), OktaProviderV1, opts) + require.NoError(t, err) + + ops := provider.Private(tt.clientID, tt.clientSecret) + token, err := ops.ClientCredentials(context.Background()) + + if tt.wantErr { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.NotNil(t, token) + assert.Equal(t, "test-token", token.AccessToken) + assert.Equal(t, "Bearer", token.TokenType) + }) + } +} + +func TestOktaUnsupportedFlows(t *testing.T) { + provider, err := OktaFactory(context.Background(), OktaProviderV1, map[string]string{ + "domain": "test.okta.com", + }) + require.NoError(t, err) + + ops := provider.Private("test-client", "test-secret") + + t.Run("device code auth", func(t *testing.T) { + auth, ok, err := ops.DeviceCodeAuth(context.Background()) + assert.Nil(t, auth) + assert.False(t, ok) + assert.Nil(t, err) + }) + + t.Run("device code exchange", func(t *testing.T) { + token, err := ops.DeviceCodeExchange(context.Background(), "test-code") + assert.Nil(t, token) + assert.EqualError(t, err, "device code flow not supported") + }) + + t.Run("auth code exchange", func(t *testing.T) { + token, err := ops.AuthCodeExchange(context.Background(), "test-code") + assert.Nil(t, token) + assert.EqualError(t, err, "auth code exchange flow not supported") + }) + + t.Run("token exchange", func(t *testing.T) { + token, err := ops.TokenExchange(context.Background(), nil) + assert.Nil(t, token) + assert.EqualError(t, err, "token exchange flow not supported") + }) +}