-
-
Notifications
You must be signed in to change notification settings - Fork 346
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: return bad request on DELETE body (#1219)
- Loading branch information
Showing
6 changed files
with
313 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// Copyright © 2023 Ory Corp | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package validate | ||
|
||
import ( | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"strings" | ||
|
||
"github.com/ory/herodot" | ||
) | ||
|
||
type Validator func(r *http.Request) (ok bool, reason string) | ||
|
||
// All runs all validators and returns an error if any of them fail. It returns | ||
// a ErrBadRequest with all failed validation messages. | ||
func All(r *http.Request, validator ...Validator) error { | ||
var reasons []string | ||
for _, v := range validator { | ||
if ok, msg := v(r); !ok { | ||
reasons = append(reasons, msg) | ||
} | ||
} | ||
if len(reasons) > 0 { | ||
return herodot.ErrBadRequest.WithReason(strings.Join(reasons, "; ")) | ||
} | ||
return nil | ||
} | ||
|
||
// NoExtraQueryParams returns a validator that checks if the request has any | ||
// query parameters that are not in the except list. | ||
func NoExtraQueryParams(except ...string) Validator { | ||
return func(req *http.Request) (ok bool, reason string) { | ||
allowed := make(map[string]struct{}, len(except)) | ||
for _, e := range except { | ||
allowed[e] = struct{}{} | ||
} | ||
for key := range req.URL.Query() { | ||
if _, found := allowed[key]; !found { | ||
return false, fmt.Sprintf("query parameter key %q unknown", key) | ||
} | ||
} | ||
return true, "" | ||
} | ||
} | ||
|
||
// QueryParamsContainsOneOf returns a validator that checks if the request has | ||
// at least one of the specified query parameters. | ||
func QueryParamsContainsOneOf(keys ...string) Validator { | ||
return func(req *http.Request) (ok bool, reason string) { | ||
oneOfKeys := make(map[string]struct{}, len(keys)) | ||
for _, k := range keys { | ||
oneOfKeys[k] = struct{}{} | ||
} | ||
for key := range req.URL.Query() { | ||
if _, found := oneOfKeys[key]; found { | ||
return true, "" | ||
} | ||
} | ||
return false, fmt.Sprintf("quey parameters must specify at least one of the following: %s", strings.Join(keys, ", ")) | ||
} | ||
} | ||
|
||
// HasEmptyBody returns a validator that checks if the request body is empty. | ||
func HasEmptyBody() Validator { | ||
return func(r *http.Request) (ok bool, reason string) { | ||
_, err := r.Body.Read([]byte{}) | ||
if err != io.EOF { | ||
return false, "body is not empty" | ||
} | ||
return true, "" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
// Copyright © 2023 Ory Corp | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package validate_test | ||
|
||
import ( | ||
"io" | ||
"net/http" | ||
"net/url" | ||
"strings" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
|
||
"github.com/ory/keto/internal/x/validate" | ||
) | ||
|
||
func toURL(t *testing.T, s string) *url.URL { | ||
u, err := url.Parse(s) | ||
require.NoError(t, err) | ||
return u | ||
} | ||
|
||
func TestValidateNoExtraParams(t *testing.T) { | ||
for _, tt := range []struct { | ||
name string | ||
req *http.Request | ||
assertErr assert.ErrorAssertionFunc | ||
}{ | ||
{ | ||
name: "empty", | ||
req: &http.Request{URL: toURL(t, "https://example.com")}, | ||
assertErr: assert.NoError, | ||
}, | ||
{ | ||
name: "all params", | ||
req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=baz")}, | ||
assertErr: assert.NoError, | ||
}, | ||
{ | ||
name: "extra params", | ||
req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=2&baz=3")}, | ||
assertErr: assert.Error, | ||
}, | ||
} { | ||
t.Run("case="+tt.name, func(t *testing.T) { | ||
err := validate.All(tt.req, validate.NoExtraQueryParams("foo", "bar")) | ||
tt.assertErr(t, err) | ||
}) | ||
} | ||
} | ||
|
||
func TestQueryParamsContainsOneOf(t *testing.T) { | ||
for _, tt := range []struct { | ||
name string | ||
req *http.Request | ||
assertErr assert.ErrorAssertionFunc | ||
}{ | ||
{ | ||
name: "empty", | ||
req: &http.Request{URL: toURL(t, "https://example.com")}, | ||
assertErr: assert.Error, | ||
}, | ||
{ | ||
name: "other", | ||
req: &http.Request{URL: toURL(t, "https://example.com?a=1&b=2&c=3")}, | ||
assertErr: assert.Error, | ||
}, | ||
{ | ||
name: "one", | ||
req: &http.Request{URL: toURL(t, "https://example.com?foo=1")}, | ||
assertErr: assert.NoError, | ||
}, | ||
{ | ||
name: "all params", | ||
req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=baz")}, | ||
assertErr: assert.NoError, | ||
}, | ||
{ | ||
name: "extra params", | ||
req: &http.Request{URL: toURL(t, "https://example.com?foo=1&bar=2&baz=3")}, | ||
assertErr: assert.NoError, | ||
}, | ||
} { | ||
t.Run("case="+tt.name, func(t *testing.T) { | ||
err := validate.All(tt.req, validate.QueryParamsContainsOneOf("foo", "bar")) | ||
tt.assertErr(t, err) | ||
}) | ||
} | ||
} | ||
|
||
func TestValidateHasEmptyBody(t *testing.T) { | ||
for _, tt := range []struct { | ||
name string | ||
req *http.Request | ||
assertErr assert.ErrorAssertionFunc | ||
}{ | ||
{ | ||
name: "empty body", | ||
req: &http.Request{Body: io.NopCloser(strings.NewReader(""))}, | ||
assertErr: assert.NoError, | ||
}, | ||
{ | ||
name: "non-empty body", | ||
req: &http.Request{Body: io.NopCloser(strings.NewReader("content"))}, | ||
assertErr: assert.Error, | ||
}, | ||
} { | ||
t.Run("case="+tt.name, func(t *testing.T) { | ||
err := validate.All(tt.req, validate.HasEmptyBody()) | ||
tt.assertErr(t, err) | ||
}) | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.