Skip to content

Commit

Permalink
Set allowed headers via API instead of defaulting to wildcard. (#3023)
Browse files Browse the repository at this point in the history
  • Loading branch information
Aaron Salvo authored and jefferai committed Aug 7, 2017
1 parent 8726b2c commit b837a1f
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 24 deletions.
13 changes: 3 additions & 10 deletions http/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,6 @@ import (
"github.com/hashicorp/vault/vault"
)

var preflightHeaders = map[string]string{
"Access-Control-Allow-Headers": "*",
"Access-Control-Max-Age": "300",
}

var allowedMethods = []string{
http.MethodDelete,
http.MethodGet,
Expand All @@ -38,8 +33,7 @@ func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler {
return
}

// Return a 403 if the origin is not
// allowed to make cross-origin requests.
// Return a 403 if the origin is not allowed to make cross-origin requests.
if !corsConf.IsValidOrigin(origin) {
respondError(w, http.StatusForbidden, fmt.Errorf("origin not allowed"))
return
Expand All @@ -56,10 +50,9 @@ func wrapCORSHandler(h http.Handler, core *vault.Core) http.Handler {
// apply headers for preflight requests
if req.Method == http.MethodOptions {
w.Header().Set("Access-Control-Allow-Methods", strings.Join(allowedMethods, ","))
w.Header().Set("Access-Control-Allow-Headers", strings.Join(corsConf.AllowedHeaders, ","))
w.Header().Set("Access-Control-Max-Age", "300")

for k, v := range preflightHeaders {
w.Header().Set(k, v)
}
return
}

Expand Down
5 changes: 3 additions & 2 deletions http/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"

"github.com/hashicorp/go-cleanhttp"
Expand All @@ -21,7 +22,7 @@ func TestHandler_cors(t *testing.T) {

// Enable CORS and allow from any origin for testing.
corsConfig := core.CORSConfig()
err := corsConfig.Enable([]string{addr})
err := corsConfig.Enable([]string{addr}, nil)
if err != nil {
t.Fatalf("Error enabling CORS: %s", err)
}
Expand Down Expand Up @@ -78,7 +79,7 @@ func TestHandler_cors(t *testing.T) {
//
expHeaders := map[string]string{
"Access-Control-Allow-Origin": addr,
"Access-Control-Allow-Headers": "*",
"Access-Control-Allow-Headers": strings.Join(stdAllowedHeaders, ","),
"Access-Control-Max-Age": "300",
"Vary": "Origin",
}
Expand Down
78 changes: 78 additions & 0 deletions http/sys_config_cors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package http

import (
"encoding/json"
"net/http"
"reflect"
"testing"

"github.com/hashicorp/vault/vault"
)

func TestSysConfigCors(t *testing.T) {
var resp *http.Response

core, _, token := vault.TestCoreUnsealed(t)
ln, addr := TestServer(t, core)
defer ln.Close()
TestServerAuth(t, addr, token)

corsConf := core.CORSConfig()

// Try to enable CORS without providing a value for allowed_origins
resp = testHttpPut(t, token, addr+"/v1/sys/config/cors", map[string]interface{}{
"allowed_headers": "X-Custom-Header",
})

testResponseStatus(t, resp, 500)

// Enable CORS, but provide an origin this time.
resp = testHttpPut(t, token, addr+"/v1/sys/config/cors", map[string]interface{}{
"allowed_origins": addr,
"allowed_headers": "X-Custom-Header",
})

testResponseStatus(t, resp, 204)

// Read the CORS configuration
resp = testHttpGet(t, token, addr+"/v1/sys/config/cors")
testResponseStatus(t, resp, 200)

var actual map[string]interface{}
var expected map[string]interface{}

lenStdHeaders := len(corsConf.AllowedHeaders)

expectedHeaders := make([]interface{}, lenStdHeaders)

for i := range corsConf.AllowedHeaders {
expectedHeaders[i] = corsConf.AllowedHeaders[i]
}

expected = map[string]interface{}{
"lease_id": "",
"renewable": false,
"lease_duration": json.Number("0"),
"wrap_info": nil,
"warnings": nil,
"auth": nil,
"data": map[string]interface{}{
"enabled": true,
"allowed_origins": []interface{}{addr},
"allowed_headers": expectedHeaders,
},
"enabled": true,
"allowed_origins": []interface{}{addr},
"allowed_headers": expectedHeaders,
}

testResponseStatus(t, resp, 200)

testResponseBody(t, resp, &actual)
expected["request_id"] = actual["request_id"]

if !reflect.DeepEqual(actual, expected) {
t.Fatalf("bad: expected: %#v\nactual: %#v", expected, actual)
}

}
2 changes: 1 addition & 1 deletion vault/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,8 +459,8 @@ func NewCore(conf *CoreConfig) (*Core, error) {
enableMlock: !conf.DisableMlock,
}

// Load CORS config and provide core
c.corsConfig = &CORSConfig{core: c}
// Load CORS config and provide a value for the core field.

// Wrap the physical backend in a cache layer if enabled and not already wrapped
if _, isCache := conf.Physical.(*physical.Cache); !conf.DisableCache && !isCache {
Expand Down
34 changes: 30 additions & 4 deletions vault/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,24 @@ const (
CORSEnabled
)

var stdAllowedHeaders = []string{
"Content-Type",
"X-Requested-With",
"X-Vault-AWS-IAM-Server-ID",
"X-Vault-MFA",
"X-Vault-No-Request-Forwarding",
"X-Vault-Token",
"X-Vault-Wrap-Format",
"X-Vault-Wrap-TTL",
}

// CORSConfig stores the state of the CORS configuration.
type CORSConfig struct {
sync.RWMutex `json:"-"`
core *Core
Enabled uint32 `json:"enabled"`
AllowedOrigins []string `json:"allowed_origins,omitempty"`
AllowedHeaders []string `json:"allowed_headers,omitempty"`
}

func (c *Core) saveCORSConfig() error {
Expand All @@ -31,6 +43,7 @@ func (c *Core) saveCORSConfig() error {
}
c.corsConfig.RLock()
localConfig.AllowedOrigins = c.corsConfig.AllowedOrigins
localConfig.AllowedHeaders = c.corsConfig.AllowedHeaders
c.corsConfig.RUnlock()

entry, err := logical.StorageEntryJSON("cors", localConfig)
Expand Down Expand Up @@ -72,9 +85,9 @@ func (c *Core) loadCORSConfig() error {

// Enable takes either a '*' or a comma-seprated list of URLs that can make
// cross-origin requests to Vault.
func (c *CORSConfig) Enable(urls []string) error {
func (c *CORSConfig) Enable(urls []string, headers []string) error {
if len(urls) == 0 {
return errors.New("the list of allowed origins cannot be empty")
return errors.New("at least one origin or the wildcard must be provided.")
}

if strutil.StrListContains(urls, "*") && len(urls) > 1 {
Expand All @@ -83,6 +96,15 @@ func (c *CORSConfig) Enable(urls []string) error {

c.Lock()
c.AllowedOrigins = urls

// Start with the standard headers to Vault accepts.
c.AllowedHeaders = append(c.AllowedHeaders, stdAllowedHeaders...)

// Allow the user to add additional headers to the list of
// headers allowed on cross-origin requests.
if len(headers) > 0 {
c.AllowedHeaders = append(c.AllowedHeaders, headers...)
}
c.Unlock()

atomic.StoreUint32(&c.Enabled, CORSEnabled)
Expand All @@ -95,12 +117,16 @@ func (c *CORSConfig) IsEnabled() bool {
return atomic.LoadUint32(&c.Enabled) == CORSEnabled
}

// Disable sets CORS to disabled and clears the allowed origins
// Disable sets CORS to disabled and clears the allowed origins & headers.
func (c *CORSConfig) Disable() error {
atomic.StoreUint32(&c.Enabled, CORSDisabled)
c.Lock()
c.AllowedOrigins = []string(nil)

c.AllowedOrigins = nil
c.AllowedHeaders = nil

c.Unlock()

return c.core.saveCORSConfig()
}

Expand Down
12 changes: 9 additions & 3 deletions vault/logical_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ func NewSystemBackend(core *Core) *SystemBackend {
Type: framework.TypeCommaStringSlice,
Description: "A comma-separated string or array of strings indicating origins that may make cross-origin requests.",
},
"allowed_headers": &framework.FieldSchema{
Type: framework.TypeCommaStringSlice,
Description: "A comma-separated string or array of strings indicating headers that are allowed on cross-origin requests.",
},
},

Callbacks: map[logical.Operation]framework.OperationFunc{
Expand Down Expand Up @@ -854,6 +858,7 @@ func (b *SystemBackend) handleCORSRead(req *logical.Request, d *framework.FieldD
if enabled {
corsConf.RLock()
resp.Data["allowed_origins"] = corsConf.AllowedOrigins
resp.Data["allowed_headers"] = corsConf.AllowedHeaders
corsConf.RUnlock()
}

Expand All @@ -864,12 +869,13 @@ func (b *SystemBackend) handleCORSRead(req *logical.Request, d *framework.FieldD
// cross-origin requests and sets the CORS enabled flag to true
func (b *SystemBackend) handleCORSUpdate(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
origins := d.Get("allowed_origins").([]string)
headers := d.Get("allowed_headers").([]string)

return nil, b.Core.corsConfig.Enable(origins)
return nil, b.Core.corsConfig.Enable(origins, headers)
}

// handleCORSDelete clears the allowed origins and sets the CORS enabled flag
// to false
// handleCORSDelete sets the CORS enabled flag to false and clears the list of
// allowed origins & headers.
func (b *SystemBackend) handleCORSDelete(req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
return nil, b.Core.corsConfig.Disable()
}
Expand Down
2 changes: 2 additions & 0 deletions vault/logical_system_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ func TestSystemConfigCORS(t *testing.T) {

req := logical.TestRequest(t, logical.UpdateOperation, "config/cors")
req.Data["allowed_origins"] = "http://www.example.com"
req.Data["allowed_headers"] = "X-Custom-Header"
_, err := b.HandleRequest(req)
if err != nil {
t.Fatal(err)
Expand All @@ -65,6 +66,7 @@ func TestSystemConfigCORS(t *testing.T) {
Data: map[string]interface{}{
"enabled": true,
"allowed_origins": []string{"http://www.example.com"},
"allowed_headers": append(stdAllowedHeaders, "X-Custom-Header"),
},
}

Expand Down
20 changes: 16 additions & 4 deletions website/source/api/system/config-cors.html.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,40 @@ $ curl \
```json
{
"enabled": true,
"allowed_origins": "http://www.example.com"
"allowed_origins": ["http://www.example.com"],
"allowed_headers": [
"Content-Type",
"X-Requested-With",
"X-Vault-AWS-IAM-Server-ID",
"X-Vault-No-Request-Forwarding",
"X-Vault-Token",
"X-Vault-Wrap-Format",
"X-Vault-Wrap-TTL",
]
}
```

## Configure CORS Settings

This endpoint allows configuring the origins that are permitted to make
cross-origin requests.
cross-origin requests, as well as headers that are allowed on cross-origin requests.

| Method | Path | Produces |
| :------- | :--------------------------- | :--------------------- |
| `PUT` | `/sys/config/cors` | `204 (empty body)` |

### Parameters

- `allowed_origins` `(string or string array: "" or [])` – A wildcard (`*`), comma-delimited string, or array of strings specifying the origins that are permitted to make cross-origin requests.
- `allowed_origins` `(string or string array: <required>)` – A wildcard (`*`), comma-delimited string, or array of strings specifying the origins that are permitted to make cross-origin requests.

- `allowed_headers` `(string or string array: "" or [])` – A comma-delimited string or array of strings specifying headers that are permitted to be on cross-origin requests. Headers set via this parameter will be appended to the list of headers that Vault allows by default.

### Sample Payload

```json
{
"allowed_origins": "*"
"allowed_origins": "*",
"allowed_headers": "X-Custom-Header"
}
```

Expand Down

0 comments on commit b837a1f

Please sign in to comment.