Skip to content

Commit

Permalink
feat: return bad request on DELETE body (#1219)
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl authored Feb 2, 2023
1 parent cbbb91e commit 195182c
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 22 deletions.
10 changes: 10 additions & 0 deletions internal/relationtuple/transact_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"encoding/json"
"net/http"

"github.com/ory/keto/internal/x/validate"
"github.com/ory/keto/ketoapi"

rts "github.com/ory/keto/proto/ory/keto/relation_tuples/v1alpha2"
Expand Down Expand Up @@ -175,6 +176,15 @@ func (h *handler) createRelation(w http.ResponseWriter, r *http.Request, _ httpr
func (h *handler) deleteRelations(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
ctx := r.Context()

if err := validate.All(r,
validate.NoExtraQueryParams(ketoapi.RelationQueryKeys...),
validate.QueryParamsContainsOneOf(ketoapi.NamespaceKey),
validate.HasEmptyBody(),
); err != nil {
h.d.Writer().WriteError(w, r, err)
return
}

q := r.URL.Query()
query, err := (&ketoapi.RelationQuery{}).FromURLQuery(q)
if err != nil {
Expand Down
90 changes: 83 additions & 7 deletions internal/relationtuple/transact_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,20 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/julienschmidt/httprouter"
"github.com/ory/x/pointerx"

"github.com/ory/keto/ketoapi"

"github.com/ory/keto/internal/driver/config"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/julienschmidt/httprouter"

"github.com/ory/keto/internal/driver"
"github.com/ory/keto/internal/driver/config"
"github.com/ory/keto/internal/namespace"
"github.com/ory/keto/internal/relationtuple"
"github.com/ory/keto/internal/x"
"github.com/ory/keto/ketoapi"
)

func TestWriteHandlers(t *testing.T) {
Expand Down Expand Up @@ -218,6 +215,85 @@ func TestWriteHandlers(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, []*relationtuple.RelationTuple{}, actualRTs)
})

t.Run("suite=bad requests", func(t *testing.T) {
nspace := addNamespace(t)

rts := []*ketoapi.RelationTuple{
{
Namespace: nspace.Name,
Object: "deleted obj",
Relation: "deleted rel",
SubjectID: pointerx.Ptr("deleted subj 1"),
},
{
Namespace: nspace.Name,
Object: "deleted obj",
Relation: "deleted rel",
SubjectID: pointerx.Ptr("deleted subj 2"),
},
}

relationtuple.MapAndWriteTuples(t, reg, rts...)

assertBadRequest := func(t *testing.T, req *http.Request) {
resp, err := ts.Client().Do(req)
require.NoError(t, err)
assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
}

assertTuplesExist := func(t *testing.T) {
mappedQuery, err := reg.Mapper().FromQuery(ctx, &ketoapi.RelationQuery{
Namespace: &nspace.Name,
})
require.NoError(t, err)

actualRTs, _, err := reg.RelationTupleManager().GetRelationTuples(ctx, mappedQuery, x.WithSize(10))
require.NoError(t, err)
mappedRTs, err := reg.Mapper().ToTuple(ctx, actualRTs...)
require.NoError(t, err)
assert.ElementsMatch(t, rts, mappedRTs)
}

t.Run("case=bad request if body sent", func(t *testing.T) {
q := url.Values{
"namespace": {nspace.Name},
"object": {"deleted obj"},
"relation": {"deleted rel"},
}
req, err := http.NewRequest(
http.MethodDelete,
ts.URL+relationtuple.WriteRouteBase+"?"+q.Encode(),
strings.NewReader("some body"))
require.NoError(t, err)

assertBadRequest(t, req)
assertTuplesExist(t)
})

t.Run("case=bad request query param misspelled", func(t *testing.T) {
req, err := http.NewRequest(
http.MethodDelete,
ts.URL+relationtuple.WriteRouteBase+"?invalid=param",
nil)
require.NoError(t, err)

assertBadRequest(t, req)
assertTuplesExist(t)
})

t.Run("case=bad request if query params misssing", func(t *testing.T) {
req, err := http.NewRequest(
http.MethodDelete,
ts.URL+relationtuple.WriteRouteBase,
nil)
require.NoError(t, err)

assertBadRequest(t, req)
assertTuplesExist(t)
})
})

})

t.Run("method=patch", func(t *testing.T) {
Expand Down
75 changes: 75 additions & 0 deletions internal/x/validate/validate.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package validate

import (
"fmt"
"io"
"net/http"
"strings"

"github.com/ory/herodot"
)

type Validator func(r *http.Request) (ok bool, reason string)

// All runs all validators and returns an error if any of them fail. It returns
// a ErrBadRequest with all failed validation messages.
func All(r *http.Request, validator ...Validator) error {
var reasons []string
for _, v := range validator {
if ok, msg := v(r); !ok {
reasons = append(reasons, msg)
}
}
if len(reasons) > 0 {
return herodot.ErrBadRequest.WithReason(strings.Join(reasons, "; "))
}
return nil
}

// NoExtraQueryParams returns a validator that checks if the request has any
// query parameters that are not in the except list.
func NoExtraQueryParams(except ...string) Validator {
return func(req *http.Request) (ok bool, reason string) {
allowed := make(map[string]struct{}, len(except))
for _, e := range except {
allowed[e] = struct{}{}
}
for key := range req.URL.Query() {
if _, found := allowed[key]; !found {
return false, fmt.Sprintf("query parameter key %q unknown", key)
}
}
return true, ""
}
}

// QueryParamsContainsOneOf returns a validator that checks if the request has
// at least one of the specified query parameters.
func QueryParamsContainsOneOf(keys ...string) Validator {
return func(req *http.Request) (ok bool, reason string) {
oneOfKeys := make(map[string]struct{}, len(keys))
for _, k := range keys {
oneOfKeys[k] = struct{}{}
}
for key := range req.URL.Query() {
if _, found := oneOfKeys[key]; found {
return true, ""
}
}
return false, fmt.Sprintf("quey parameters must specify at least one of the following: %s", strings.Join(keys, ", "))
}
}

// HasEmptyBody returns a validator that checks if the request body is empty.
func HasEmptyBody() Validator {
return func(r *http.Request) (ok bool, reason string) {
_, err := r.Body.Read([]byte{})
if err != io.EOF {
return false, "body is not empty"
}
return true, ""
}
}
116 changes: 116 additions & 0 deletions internal/x/validate/validate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package validate_test

import (
"io"
"net/http"
"net/url"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/ory/keto/internal/x/validate"
)

func toURL(t *testing.T, s string) *url.URL {
u, err := url.Parse(s)
require.NoError(t, err)
return u
}

func TestValidateNoExtraParams(t *testing.T) {
for _, tt := range []struct {
name string
req *http.Request
assertErr assert.ErrorAssertionFunc
}{
{
name: "empty",
req: &http.Request{URL: toURL(t, "https://example.com")},
assertErr: assert.NoError,
},
{
name: "all params",
req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=baz")},
assertErr: assert.NoError,
},
{
name: "extra params",
req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=2&baz=3")},
assertErr: assert.Error,
},
} {
t.Run("case="+tt.name, func(t *testing.T) {
err := validate.All(tt.req, validate.NoExtraQueryParams("foo", "bar"))
tt.assertErr(t, err)
})
}
}

func TestQueryParamsContainsOneOf(t *testing.T) {
for _, tt := range []struct {
name string
req *http.Request
assertErr assert.ErrorAssertionFunc
}{
{
name: "empty",
req: &http.Request{URL: toURL(t, "https://example.com")},
assertErr: assert.Error,
},
{
name: "other",
req: &http.Request{URL: toURL(t, "https://example.com?a=1&b=2&c=3")},
assertErr: assert.Error,
},
{
name: "one",
req: &http.Request{URL: toURL(t, "https://example.com?foo=1")},
assertErr: assert.NoError,
},
{
name: "all params",
req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=baz")},
assertErr: assert.NoError,
},
{
name: "extra params",
req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=2&baz=3")},
assertErr: assert.NoError,
},
} {
t.Run("case="+tt.name, func(t *testing.T) {
err := validate.All(tt.req, validate.QueryParamsContainsOneOf("foo", "bar"))
tt.assertErr(t, err)
})
}
}

func TestValidateHasEmptyBody(t *testing.T) {
for _, tt := range []struct {
name string
req *http.Request
assertErr assert.ErrorAssertionFunc
}{
{
name: "empty body",
req: &http.Request{Body: io.NopCloser(strings.NewReader(""))},
assertErr: assert.NoError,
},
{
name: "non-empty body",
req: &http.Request{Body: io.NopCloser(strings.NewReader("content"))},
assertErr: assert.Error,
},
} {
t.Run("case="+tt.name, func(t *testing.T) {
err := validate.All(tt.req, validate.HasEmptyBody())
tt.assertErr(t, err)
})
}

}
30 changes: 15 additions & 15 deletions ketoapi/enc_url_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ func (q *RelationQuery) FromURLQuery(query url.Values) (*RelationQuery, error) {
return nil, ErrIncompleteSubject
}

if query.Has("namespace") {
q.Namespace = pointerx.Ptr(query.Get("namespace"))
if query.Has(NamespaceKey) {
q.Namespace = pointerx.Ptr(query.Get(NamespaceKey))
}
if query.Has("object") {
q.Object = pointerx.Ptr(query.Get("object"))
if query.Has(ObjectKey) {
q.Object = pointerx.Ptr(query.Get(ObjectKey))
}
if query.Has("relation") {
q.Relation = pointerx.Ptr(query.Get("relation"))
if query.Has(RelationKey) {
q.Relation = pointerx.Ptr(query.Get(RelationKey))
}

return q, nil
Expand All @@ -57,13 +57,13 @@ func (q *RelationQuery) ToURLQuery() url.Values {
v := make(url.Values, 7)

if q.Namespace != nil {
v.Add("namespace", *q.Namespace)
v.Add(NamespaceKey, *q.Namespace)
}
if q.Relation != nil {
v.Add("relation", *q.Relation)
v.Add(RelationKey, *q.Relation)
}
if q.Object != nil {
v.Add("object", *q.Object)
v.Add(ObjectKey, *q.Object)
}
if q.SubjectID != nil {
v.Add(SubjectIDKey, *q.SubjectID)
Expand Down Expand Up @@ -112,17 +112,17 @@ func (s *SubjectSet) FromURLQuery(values url.Values) *SubjectSet {
s = &SubjectSet{}
}

s.Namespace = values.Get("namespace")
s.Relation = values.Get("relation")
s.Object = values.Get("object")
s.Namespace = values.Get(NamespaceKey)
s.Relation = values.Get(RelationKey)
s.Object = values.Get(ObjectKey)

return s
}

func (s *SubjectSet) ToURLQuery() url.Values {
return url.Values{
"namespace": []string{s.Namespace},
"object": []string{s.Object},
"relation": []string{s.Relation},
NamespaceKey: []string{s.Namespace},
ObjectKey: []string{s.Object},
RelationKey: []string{s.Relation},
}
}
Loading

0 comments on commit 195182c

Please sign in to comment.