Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check that all required fields in Transit API are present. #14074

Merged
merged 4 commits into from
Feb 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion builtin/logical/transit/path_decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (b *backend) pathDecryptWrite(ctx context.Context, req *logical.Request, d
var batchInputItems []BatchRequestItem
var err error
if batchInputRaw != nil {
err = decodeBatchRequestItems(batchInputRaw, &batchInputItems)
err = decodeDecryptBatchRequestItems(batchInputRaw, &batchInputItems)
if err != nil {
return nil, fmt.Errorf("failed to parse batch input: %w", err)
}
Expand Down
26 changes: 21 additions & 5 deletions builtin/logical/transit/path_encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,19 @@ to the min_encryption_version configured on the key.`,
}
}

func decodeEncryptBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error {
return decodeBatchRequestItems(src, true, false, dst)
}

func decodeDecryptBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error {
return decodeBatchRequestItems(src, false, true, dst)
}

// decodeBatchRequestItems is a fast path alternative to mapstructure.Decode to decode []BatchRequestItem.
// It aims to behave as closely possible to the original mapstructure.Decode and will return the same errors.
// Note, however, that an error will also be returned if one of the required fields is missing.
// https://github.com/hashicorp/vault/pull/8775/files#r437709722
func decodeBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error {
func decodeBatchRequestItems(src interface{}, requirePlaintext bool, requireCiphertext bool, dst *[]BatchRequestItem) error {
if src == nil || dst == nil {
return nil
}
Expand Down Expand Up @@ -173,15 +182,18 @@ func decodeBatchRequestItems(src interface{}, dst *[]BatchRequestItem) error {
} else {
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].ciphertext' expected type 'string', got unconvertible type '%T'", i, item["ciphertext"]))
}
} else if requireCiphertext {
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].ciphertext' missing ciphertext to decrypt", i))
}

// don't allow "null" to be passed in for the plaintext value
if v, has := item["plaintext"]; has {
if casted, ok := v.(string); ok {
(*dst)[i].Plaintext = casted
} else {
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].plaintext' expected type 'string', got unconvertible type '%T'", i, item["plaintext"]))
}
} else if requirePlaintext {
errs.Errors = append(errs.Errors, fmt.Sprintf("'[%d].plaintext' missing plaintext to encrypt", i))
}

if v, has := item["nonce"]; has {
Expand Down Expand Up @@ -240,7 +252,7 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
batchInputRaw := d.Raw["batch_input"]
var batchInputItems []BatchRequestItem
if batchInputRaw != nil {
err = decodeBatchRequestItems(batchInputRaw, &batchInputItems)
err = decodeEncryptBatchRequestItems(batchInputRaw, &batchInputItems)
if err != nil {
return nil, fmt.Errorf("failed to parse batch input: %w", err)
}
Expand All @@ -249,14 +261,18 @@ func (b *backend) pathEncryptWrite(ctx context.Context, req *logical.Request, d
return logical.ErrorResponse("missing batch input to process"), logical.ErrInvalidRequest
}
} else {
valueRaw, ok := d.GetOk("plaintext")
valueRaw, ok := d.Raw["plaintext"]
if !ok {
return logical.ErrorResponse("missing plaintext to encrypt"), logical.ErrInvalidRequest
}
plaintext, ok := valueRaw.(string)
if !ok {
return logical.ErrorResponse("expected plaintext of type 'string', got unconvertible type '%T'", valueRaw), logical.ErrInvalidRequest
}

batchInputItems = make([]BatchRequestItem, 1)
batchInputItems[0] = BatchRequestItem{
Plaintext: valueRaw.(string),
Plaintext: plaintext,
Context: d.Get("context").(string),
Nonce: d.Get("nonce").(string),
KeyVersion: d.Get("key_version").(int),
Expand Down
120 changes: 113 additions & 7 deletions builtin/logical/transit/path_encrypt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,75 @@ import (
"github.com/mitchellh/mapstructure"
)

func TestTransit_MissingPlaintext(t *testing.T) {
var resp *logical.Response
var err error

b, s := createBackendWithStorage(t)

// Create the policy
policyReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "keys/existing_key",
Storage: s,
}
resp, err = b.HandleRequest(context.Background(), policyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}

encData := map[string]interface{}{
"plaintext": nil,
}

encReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "encrypt/existing_key",
Storage: s,
Data: encData,
}
resp, err = b.HandleRequest(context.Background(), encReq)
if resp == nil || !resp.IsError() {
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
}
}

func TestTransit_MissingPlaintextInBatchInput(t *testing.T) {
var resp *logical.Response
var err error

b, s := createBackendWithStorage(t)

// Create the policy
policyReq := &logical.Request{
Operation: logical.UpdateOperation,
Path: "keys/existing_key",
Storage: s,
}
resp, err = b.HandleRequest(context.Background(), policyReq)
if err != nil || (resp != nil && resp.IsError()) {
t.Fatalf("err:%v resp:%#v", err, resp)
}

batchInput := []interface{}{
map[string]interface{}{}, // Note that there is no map entry for plaintext
}

batchData := map[string]interface{}{
"batch_input": batchInput,
}
batchReq := &logical.Request{
Operation: logical.CreateOperation,
Path: "encrypt/upserted_key",
Storage: s,
Data: batchData,
}
resp, err = b.HandleRequest(context.Background(), batchReq)
if err == nil {
t.Fatalf("expected error due to missing plaintext in request, err:%v resp:%#v", err, resp)
}
}

// Case1: Ensure that batch encryption did not affect the normal flow of
// encrypting the plaintext with a pre-existing key.
func TestTransit_BatchEncryptionCase1(t *testing.T) {
Expand Down Expand Up @@ -607,10 +676,12 @@ func TestTransit_BatchEncryptionCase13(t *testing.T) {
// Test that the fast path function decodeBatchRequestItems behave like mapstructure.Decode() to decode []BatchRequestItem.
func TestTransit_decodeBatchRequestItems(t *testing.T) {
tests := []struct {
name string
src interface{}
dest []BatchRequestItem
wantErrContains string
name string
src interface{}
requirePlaintext bool
requireCiphertext bool
dest []BatchRequestItem
wantErrContains string
}{
// basic edge cases of nil values
{name: "nil-nil", src: nil, dest: nil},
Expand Down Expand Up @@ -729,16 +800,51 @@ func TestTransit_decodeBatchRequestItems(t *testing.T) {
src: []interface{}{map[string]interface{}{"plaintext": "dGhlIHF1aWNrIGJyb3duIGZveA==", "nonce": "null"}},
dest: []BatchRequestItem{},
},
// required fields
{
name: "required_plaintext_present",
src: []interface{}{map[string]interface{}{"plaintext": ""}},
requirePlaintext: true,
dest: []BatchRequestItem{},
},
{
name: "required_plaintext_missing",
src: []interface{}{map[string]interface{}{}},
requirePlaintext: true,
dest: []BatchRequestItem{},
wantErrContains: "missing plaintext",
},
{
name: "required_ciphertext_present",
src: []interface{}{map[string]interface{}{"ciphertext": "dGhlIHF1aWNrIGJyb3duIGZveA=="}},
requireCiphertext: true,
dest: []BatchRequestItem{},
},
{
name: "required_ciphertext_missing",
src: []interface{}{map[string]interface{}{}},
requireCiphertext: true,
dest: []BatchRequestItem{},
wantErrContains: "missing ciphertext",
},
{
name: "required_plaintext_and_ciphertext_missing",
src: []interface{}{map[string]interface{}{}},
requirePlaintext: true,
requireCiphertext: true,
dest: []BatchRequestItem{},
wantErrContains: "missing ciphertext",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
expectedDest := append(tt.dest[:0:0], tt.dest...) // copy of the dest state
expectedErr := mapstructure.Decode(tt.src, &expectedDest)
expectedErr := mapstructure.Decode(tt.src, &expectedDest) != nil || tt.wantErrContains != ""

gotErr := decodeBatchRequestItems(tt.src, &tt.dest)
gotErr := decodeBatchRequestItems(tt.src, tt.requirePlaintext, tt.requireCiphertext, &tt.dest)
gotDest := tt.dest

if expectedErr != nil {
if expectedErr {
if gotErr == nil {
t.Fatal("decodeBatchRequestItems unexpected error value; expected error but got none")
}
Expand Down
11 changes: 10 additions & 1 deletion builtin/logical/transit/path_hash.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,16 @@ Defaults to "sha2-256".`,
}

func (b *backend) pathHashWrite(ctx context.Context, req *logical.Request, d *framework.FieldData) (*logical.Response, error) {
inputB64 := d.Get("input").(string)
rawInput, ok := d.Raw["input"]
if !ok {
return logical.ErrorResponse("input missing"), logical.ErrInvalidRequest
}

inputB64, ok := rawInput.(string)
if !ok {
return logical.ErrorResponse("expected input of type 'string', got unconvertible type '%T'", rawInput), logical.ErrInvalidRequest
}

format := d.Get("format").(string)
algorithm := d.Get("urlalgorithm").(string)
if algorithm == "" {
Expand Down
6 changes: 5 additions & 1 deletion builtin/logical/transit/path_hash_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestTransit_Hash(t *testing.T) {
}
if errExpected {
if !resp.IsError() {
t.Fatalf("bad: got error response: %#v", *resp)
t.Fatalf("bad: did not get error response: %#v", *resp)
}
return
}
Expand Down Expand Up @@ -86,6 +86,10 @@ func TestTransit_Hash(t *testing.T) {
doRequest(req, false, "98rFrYMEIqVAizamCmBiBoe+GAdlo+KJW8O9vYV8nggkbIMGTU42EvDLkn8+rSCEE6uYYkv3sGF68PA/YggJdg==")

// Test bad input/format/algorithm
req.Data["input"] = nil
doRequest(req, true, "")

req.Data["input"] = "dGhlIHF1aWNrIGJyb3duIGZveA=="
req.Data["format"] = "base92"
doRequest(req, true, "")

Expand Down
7 changes: 5 additions & 2 deletions builtin/logical/transit/path_trim.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,14 @@ func (b *backend) pathTrimUpdate() framework.OperationFunc {
}
defer p.Unlock()

minAvailableVersionRaw, ok := d.GetOk("min_available_version")
minAvailableVersionRaw, ok := d.Raw["min_available_version"]
if !ok {
return logical.ErrorResponse("missing min_available_version"), nil
}
minAvailableVersion := minAvailableVersionRaw.(int)
minAvailableVersion, ok := minAvailableVersionRaw.(int)
if !ok {
return logical.ErrorResponse("expected min_available_version of type 'int', got unconvertible type '%T'", minAvailableVersionRaw), logical.ErrInvalidRequest
}

originalMinAvailableVersion := p.MinAvailableVersion

Expand Down
14 changes: 14 additions & 0 deletions builtin/logical/transit/path_trim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ func TestTransit_Trim(t *testing.T) {
}
doErrReq(t, req)

// Set min_encryption_version to 0
req.Path = "keys/aes/config"
req.Data = map[string]interface{}{
"min_encryption_version": 0,
}
doReq(t, req)

// Min available version should not be converted to 0 for nil values
req.Path = "keys/aes/trim"
req.Data = map[string]interface{}{
"min_available_version": nil,
}
doErrReq(t, req)

// Set min_encryption_version to 4
req.Path = "keys/aes/config"
req.Data = map[string]interface{}{
Expand Down
3 changes: 3 additions & 0 deletions changelog/14074.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
```release-note:bug
secrets/transit: Return an error if any required parameter is missing or nil. Do not encrypt nil plaintext as if it was an empty string.
```