Skip to content

Commit

Permalink
fix: return proper error when the grant request cannot be parsed (#3558)
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl authored Jun 29, 2023
1 parent 93ebaee commit 26f2d34
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 25 deletions.
8 changes: 7 additions & 1 deletion oauth2/trust/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"time"

"github.com/ory/fosite"
"github.com/ory/x/pagination/tokenpagination"

"github.com/ory/hydra/v2/x"
Expand Down Expand Up @@ -110,7 +111,12 @@ func (h *Handler) trustOAuth2JwtGrantIssuer(w http.ResponseWriter, r *http.Reque
var grantRequest createGrantRequest

if err := json.NewDecoder(r.Body).Decode(&grantRequest); err != nil {
h.registry.Writer().WriteError(w, r, errorsx.WithStack(err))
h.registry.Writer().WriteError(w, r,
errorsx.WithStack(&fosite.RFC6749Error{
ErrorField: "error",
DescriptionField: err.Error(),
CodeField: http.StatusBadRequest,
}))
return
}

Expand Down
20 changes: 20 additions & 0 deletions oauth2/trust/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import (
"crypto/rand"
"crypto/rsa"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/tidwall/gjson"
"gopkg.in/square/go-jose.v2"

"github.com/ory/x/pointerx"
Expand Down Expand Up @@ -154,6 +156,24 @@ func (s *HandlerTestSuite) TestGrantCanNotBeCreatedWithSubjectAndAnySubject() {
s.Require().Error(err, "expected error, because a grant with a subject and allow_any_subject cannot be created")
}

func (s *HandlerTestSuite) TestGrantCanNotBeCreatedWithUnknownJWK() {
createRequestParams := hydra.TrustOAuth2JwtGrantIssuer{
AllowAnySubject: pointerx.Ptr(true),
ExpiresAt: time.Now().Add(1 * time.Hour),
Issuer: "ory",
Jwk: hydra.JsonWebKey{
Alg: "unknown",
},
Scope: []string{"openid", "offline", "profile"},
}

_, res, err := s.hydraClient.OAuth2Api.TrustOAuth2JwtGrantIssuer(context.Background()).TrustOAuth2JwtGrantIssuer(createRequestParams).Execute()
s.Assert().Equal(http.StatusBadRequest, res.StatusCode)
body, _ := io.ReadAll(res.Body)
s.Contains(gjson.GetBytes(body, "error_description").String(), "unknown json web key type")
s.Require().Error(err, "expected error, because the key type was unknown")
}

func (s *HandlerTestSuite) TestGrantCanNotBeCreatedWithMissingFields() {
createRequestParams := s.newCreateJwtBearerGrantParams(
"",
Expand Down
49 changes: 25 additions & 24 deletions persistence/sql/persister_grant_jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@ import (
"gopkg.in/square/go-jose.v2"

"github.com/ory/hydra/v2/oauth2/trust"
"github.com/ory/x/otelx"
"github.com/ory/x/stringsx"

"github.com/ory/x/sqlcon"
)

var _ trust.GrantManager = &Persister{}

func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jose.JSONWebKey) error {
func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jose.JSONWebKey) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CreateGrant")
defer span.End()
defer otelx.End(span, &err)

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
// add key, if it doesn't exist
Expand All @@ -42,9 +43,9 @@ func (p *Persister) CreateGrant(ctx context.Context, g trust.Grant, publicKey jo
})
}

func (p *Persister) GetConcreteGrant(ctx context.Context, id string) (trust.Grant, error) {
func (p *Persister) GetConcreteGrant(ctx context.Context, id string) (_ trust.Grant, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetConcreteGrant")
defer span.End()
defer otelx.End(span, &err)

var data trust.SQLData
if err := p.QueryWithNetwork(ctx).Where("id = ?", id).First(&data); err != nil {
Expand All @@ -54,9 +55,9 @@ func (p *Persister) GetConcreteGrant(ctx context.Context, id string) (trust.Gran
return p.jwtGrantFromSQlData(data), nil
}

func (p *Persister) DeleteGrant(ctx context.Context, id string) error {
func (p *Persister) DeleteGrant(ctx context.Context, id string) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.DeleteGrant")
defer span.End()
defer otelx.End(span, &err)

return p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
grant, err := p.GetConcreteGrant(ctx, id)
Expand All @@ -72,9 +73,9 @@ func (p *Persister) DeleteGrant(ctx context.Context, id string) error {
})
}

func (p *Persister) GetGrants(ctx context.Context, limit, offset int, optionalIssuer string) ([]trust.Grant, error) {
func (p *Persister) GetGrants(ctx context.Context, limit, offset int, optionalIssuer string) (_ []trust.Grant, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetGrants")
defer span.End()
defer otelx.End(span, &err)

grantsData := make([]trust.SQLData, 0)

Expand All @@ -97,18 +98,18 @@ func (p *Persister) GetGrants(ctx context.Context, limit, offset int, optionalIs
return grants, nil
}

func (p *Persister) CountGrants(ctx context.Context) (int, error) {
func (p *Persister) CountGrants(ctx context.Context) (n int, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.CountGrants")
defer span.End()
defer otelx.End(span, &err)

n, err := p.QueryWithNetwork(ctx).
n, err = p.QueryWithNetwork(ctx).
Count(&trust.SQLData{})
return n, sqlcon.HandleError(err)
}

func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject string, keyId string) (*jose.JSONWebKey, error) {
func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject string, keyId string) (_ *jose.JSONWebKey, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKey")
defer span.End()
defer otelx.End(span, &err)

var data trust.SQLData
query := p.QueryWithNetwork(ctx).
Expand All @@ -128,9 +129,9 @@ func (p *Persister) GetPublicKey(ctx context.Context, issuer string, subject str
return &keySet.Keys[0], nil
}

func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject string) (*jose.JSONWebKeySet, error) {
func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject string) (_ *jose.JSONWebKeySet, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKeys")
defer span.End()
defer otelx.End(span, &err)

grantsData := make([]trust.SQLData, 0)
query := p.QueryWithNetwork(ctx).
Expand Down Expand Up @@ -163,9 +164,9 @@ func (p *Persister) GetPublicKeys(ctx context.Context, issuer string, subject st
return filteredKeySet, nil
}

func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyId string) ([]string, error) {
func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subject string, keyId string) (_ []string, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.GetPublicKeyScopes")
defer span.End()
defer otelx.End(span, &err)

var data trust.SQLData
query := p.QueryWithNetwork(ctx).
Expand All @@ -181,21 +182,21 @@ func (p *Persister) GetPublicKeyScopes(ctx context.Context, issuer string, subje
return p.jwtGrantFromSQlData(data).Scope, nil
}

func (p *Persister) IsJWTUsed(ctx context.Context, jti string) (bool, error) {
func (p *Persister) IsJWTUsed(ctx context.Context, jti string) (ok bool, err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.IsJWTUsed")
defer span.End()
defer otelx.End(span, &err)

err := p.ClientAssertionJWTValid(ctx, jti)
err = p.ClientAssertionJWTValid(ctx, jti)
if err != nil {
return true, nil
}

return false, nil
}

func (p *Persister) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) error {
func (p *Persister) MarkJWTUsedForTime(ctx context.Context, jti string, exp time.Time) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.MarkJWTUsedForTime")
defer span.End()
defer otelx.End(span, &err)

return p.SetClientAssertionJWT(ctx, jti, exp)
}
Expand Down Expand Up @@ -230,9 +231,9 @@ func (p *Persister) jwtGrantFromSQlData(data trust.SQLData) trust.Grant {
}
}

func (p *Persister) FlushInactiveGrants(ctx context.Context, notAfter time.Time, limit int, batchSize int) error {
func (p *Persister) FlushInactiveGrants(ctx context.Context, notAfter time.Time, _ int, _ int) (err error) {
ctx, span := p.r.Tracer(ctx).Tracer().Start(ctx, "persistence.sql.FlushInactiveGrants")
defer span.End()
defer otelx.End(span, &err)

deleteUntil := time.Now().UTC()
if deleteUntil.After(notAfter) {
Expand Down

0 comments on commit 26f2d34

Please sign in to comment.